3131# ' @param standardize logical, standardize X before fitting
3232# ' @param verbose logical, print progress
3333# ' @param seed int, random seed
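34+ # ' @param par logical, use parallel processing (only takes effect when n.cores > 1). Default FALSE.
35+ # ' @param n.cores int, number of cores for parallel processing. Default NULL, which runs sequentially even when par=TRUE (no auto-detection is performed).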
3436# '
3537# ' @return isotwas_model object containing:
3638# ' \itemize{
4648# ' isoforms sharing more exons are expected to have more similar cis-regulatory
4749# ' effects, as they share more of the same genetic signal.
4850# '
51+ # ' When par=TRUE and n.cores > 1, cross-validation is parallelized over all (lambda_graph, fold)
52+ # ' combinations, providing up to nlambda_graph * nfolds parallel tasks (capped at n.cores); otherwise it runs sequentially.
53+ # '
4954# ' @export
5055multivariate_graph_reg <- function (X ,
5156 Y ,
@@ -60,7 +65,9 @@ multivariate_graph_reg <- function(X,
6065 nfolds = 5 ,
6166 standardize = FALSE ,
6267 verbose = FALSE ,
63- seed = 123 ) {
68+ seed = 123 ,
69+ par = FALSE ,
70+ n.cores = NULL ) {
6471
6572 set.seed(seed )
6673 n <- nrow(X )
@@ -155,36 +162,62 @@ multivariate_graph_reg <- function(X,
155162 # Cross-validation over lambda1 and lambda_graph grid
156163 if (verbose ) cat(" Running cross-validation for parameter selection...\n " )
157164
158- cv_results <- matrix (Inf , nrow = length(lambda1_seq ), ncol = length(lambda_graph_seq ))
165+ # Create all (lambda_graph_idx, fold_idx) combinations for parallel execution
166+ cv_tasks <- expand.grid(lg_idx = seq_along(lambda_graph_seq ),
167+ fold_idx = 1 : nfolds )
168+ n_tasks <- nrow(cv_tasks )
159169
160- for (lg_idx in seq_along(lambda_graph_seq )) {
170+ if (verbose ) {
171+ cat(sprintf(" %d lambda_graph values x %d folds = %d parallel tasks\n " ,
172+ length(lambda_graph_seq ), nfolds , n_tasks ))
173+ }
174+
175+ # Worker function for a single (lambda_graph, fold) combination
176+ run_cv_task <- function (task_idx ) {
177+ lg_idx <- cv_tasks $ lg_idx [task_idx ]
178+ fold_idx <- cv_tasks $ fold_idx [task_idx ]
161179 lg <- lambda_graph_seq [lg_idx ]
162- if (verbose ) cat(sprintf(" lambda_graph = %.4f (%d/%d)\n " ,
163- lg , lg_idx , length(lambda_graph_seq )))
164180
165- for (fold_idx in 1 : nfolds ) {
166- train_idx <- cv_folds [[fold_idx ]]
167- test_idx <- setdiff(1 : n , train_idx )
181+ train_idx <- cv_folds [[fold_idx ]]
182+ test_idx <- setdiff(1 : n , train_idx )
168183
169- X_train <- X_scaled [train_idx , , drop = FALSE ]
170- Y_train <- Y_centered [train_idx , , drop = FALSE ]
171- X_test <- X_scaled [test_idx , , drop = FALSE ]
172- Y_test <- Y_centered [test_idx , , drop = FALSE ]
184+ X_train <- X_scaled [train_idx , , drop = FALSE ]
185+ Y_train <- Y_centered [train_idx , , drop = FALSE ]
186+ X_test <- X_scaled [test_idx , , drop = FALSE ]
187+ Y_test <- Y_centered [test_idx , , drop = FALSE ]
173188
174- # Warm start path
175- B_warm <- matrix (0 , p , q )
189+ # Warm start path over lambda1 (must be sequential)
190+ B_warm <- matrix (0 , p , q )
191+ mse_vec <- numeric (length(lambda1_seq ))
176192
177- for (l1_idx in seq_along(lambda1_seq )) {
178- l1 <- lambda1_seq [l1_idx ]
193+ for (l1_idx in seq_along(lambda1_seq )) {
194+ l1 <- lambda1_seq [l1_idx ]
179195
180- B_warm <- .fit_graph_reg(X_train , Y_train , L , l1 , lg , alpha ,
181- B_init = B_warm , max_iter = 200 , tol = 1e-4 )
196+ B_warm <- .fit_graph_reg(X_train , Y_train , L , l1 , lg , alpha ,
197+ B_init = B_warm , max_iter = 200 , tol = 1e-4 )
182198
183- pred <- X_test %*% B_warm
184- mse <- mean((Y_test - pred )^ 2 )
185- cv_results [l1_idx , lg_idx ] <- cv_results [l1_idx , lg_idx ] + mse / nfolds
186- }
199+ pred <- X_test %*% B_warm
200+ mse_vec [l1_idx ] <- mean((Y_test - pred )^ 2 )
187201 }
202+
203+ list (lg_idx = lg_idx , fold_idx = fold_idx , mse_vec = mse_vec )
204+ }
205+
206+ # Run CV tasks (parallel or sequential)
207+ if (par && ! is.null(n.cores ) && n.cores > 1 ) {
208+ if (verbose ) cat(sprintf(" Running in parallel with %d cores...\n " , n.cores ))
209+ cv_task_results <- parallel :: mclapply(1 : n_tasks , run_cv_task ,
210+ mc.cores = min(n.cores , n_tasks ))
211+ } else {
212+ if (verbose ) cat(" Running sequentially...\n " )
213+ cv_task_results <- lapply(1 : n_tasks , run_cv_task )
214+ }
215+
216+ # Aggregate results into cv_results matrix
217+ cv_results <- matrix (0 , nrow = length(lambda1_seq ), ncol = length(lambda_graph_seq ))
218+
219+ for (res in cv_task_results ) {
220+ cv_results [, res $ lg_idx ] <- cv_results [, res $ lg_idx ] + res $ mse_vec / nfolds
188221 }
189222
190223 # Find best parameters
@@ -209,8 +242,7 @@ multivariate_graph_reg <- function(X,
209242 }
210243
211244 # Get CV predictions for R² calculation
212- cv_preds <- matrix (0 , nrow = n , ncol = q )
213- for (fold_idx in 1 : nfolds ) {
245+ run_final_fold <- function (fold_idx ) {
214246 train_idx <- cv_folds [[fold_idx ]]
215247 test_idx <- setdiff(1 : n , train_idx )
216248
@@ -220,7 +252,19 @@ multivariate_graph_reg <- function(X,
220252
221253 B_fold <- .fit_graph_reg(X_train , Y_train , L , best_lambda1 , best_lambda_graph ,
222254 alpha , B_init = NULL , max_iter = 500 , tol = 1e-4 )
223- cv_preds [test_idx , ] <- X_test %*% B_fold
255+ list (test_idx = test_idx , preds = X_test %*% B_fold )
256+ }
257+
258+ if (par && ! is.null(n.cores ) && n.cores > 1 ) {
259+ fold_results <- parallel :: mclapply(1 : nfolds , run_final_fold ,
260+ mc.cores = min(n.cores , nfolds ))
261+ } else {
262+ fold_results <- lapply(1 : nfolds , run_final_fold )
263+ }
264+
265+ cv_preds <- matrix (0 , nrow = n , ncol = q )
266+ for (fr in fold_results ) {
267+ cv_preds [fr $ test_idx , ] <- fr $ preds
224268 }
225269
226270 # Add back means
0 commit comments