Skip to content

Commit 057e16d

Browse files
Add parallel support to graph_reg, stacking, and univariate methods
- multivariate_graph_reg: Add par/n.cores params, parallelize CV over (lambda_graph, fold) combinations (up to 25 parallel tasks by default) - multivariate_stacking: Pass parallel/n.cores from compute_isotwas - univariate methods: Use user's par/n.cores instead of hardcoded sequential
1 parent ebd4ef7 commit 057e16d

File tree

2 files changed

+81
-33
lines changed

2 files changed

+81
-33
lines changed

R/compute_isotwas.R

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,9 @@ compute_isotwas <- function(X,
473473
Omega = omega_list$icov[[omega_nlambda]],
474474
nfolds_stack = nfolds,
475475
verbose = FALSE,
476-
seed = seed)
476+
seed = seed,
477+
parallel = par,
478+
n.cores = n.cores)
477479
})
478480
if (!is.null(stacking_mod)) {
479481
all_models <- rlist::list.append(all_models, stacking_mod)
@@ -490,7 +492,9 @@ compute_isotwas <- function(X,
490492
alpha = alpha,
491493
nfolds = nfolds,
492494
verbose = FALSE,
493-
seed = seed)
495+
seed = seed,
496+
par = par,
497+
n.cores = n.cores)
494498
})
495499
if (!is.null(graph_mod)) {
496500
all_models <- rlist::list.append(all_models, graph_mod)
@@ -509,8 +513,8 @@ compute_isotwas <- function(X,
509513
alpha = alpha,
510514
nfolds = nfolds,
511515
verbose = FALSE,
512-
par = FALSE,
513-
n.cores = 1,
516+
par = par,
517+
n.cores = n.cores,
514518
tx_names = tx_names,
515519
seed = seed)
516520

@@ -520,8 +524,8 @@ compute_isotwas <- function(X,
520524
scale = scale,
521525
nfolds = nfolds,
522526
verbose = FALSE,
523-
par = FALSE,
524-
n.cores = 1,
527+
par = par,
528+
n.cores = n.cores,
525529
tx_names = tx_names,
526530
seed = seed)
527531

@@ -532,8 +536,8 @@ compute_isotwas <- function(X,
532536
alpha = alpha,
533537
nfolds = nfolds,
534538
verbose = FALSE,
535-
par = FALSE,
536-
n.cores = 1,
539+
par = par,
540+
n.cores = n.cores,
537541
tx_names = tx_names,
538542
seed = seed)
539543

R/multivariate_graph_reg.R

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
#' @param standardize logical, standardize X before fitting
3232
#' @param verbose logical, print progress
3333
#' @param seed int, random seed
34+
#' @param par logical, use parallel processing. Default FALSE.
35+
#' @param n.cores int, number of cores for parallel processing. Default NULL (auto-detect).
3436
#'
3537
#' @return isotwas_model object containing:
3638
#' \itemize{
@@ -46,6 +48,9 @@
4648
#' isoforms sharing more exons are expected to have more similar cis-regulatory
4749
#' effects, as they share more of the same genetic signal.
4850
#'
51+
#' When par=TRUE, cross-validation is parallelized over all (lambda_graph, fold)
52+
#' combinations, providing up to nlambda_graph * nfolds parallel tasks.
53+
#'
4954
#' @export
5055
multivariate_graph_reg <- function(X,
5156
Y,
@@ -60,7 +65,9 @@ multivariate_graph_reg <- function(X,
6065
nfolds = 5,
6166
standardize = FALSE,
6267
verbose = FALSE,
63-
seed = 123) {
68+
seed = 123,
69+
par = FALSE,
70+
n.cores = NULL) {
6471

6572
set.seed(seed)
6673
n <- nrow(X)
@@ -155,36 +162,62 @@ multivariate_graph_reg <- function(X,
155162
# Cross-validation over lambda1 and lambda_graph grid
156163
if (verbose) cat("Running cross-validation for parameter selection...\n")
157164

158-
cv_results <- matrix(Inf, nrow = length(lambda1_seq), ncol = length(lambda_graph_seq))
165+
# Create all (lambda_graph_idx, fold_idx) combinations for parallel execution
166+
cv_tasks <- expand.grid(lg_idx = seq_along(lambda_graph_seq),
167+
fold_idx = 1:nfolds)
168+
n_tasks <- nrow(cv_tasks)
159169

160-
for (lg_idx in seq_along(lambda_graph_seq)) {
170+
if (verbose) {
171+
cat(sprintf(" %d lambda_graph values x %d folds = %d parallel tasks\n",
172+
length(lambda_graph_seq), nfolds, n_tasks))
173+
}
174+
175+
# Worker function for a single (lambda_graph, fold) combination
176+
run_cv_task <- function(task_idx) {
177+
lg_idx <- cv_tasks$lg_idx[task_idx]
178+
fold_idx <- cv_tasks$fold_idx[task_idx]
161179
lg <- lambda_graph_seq[lg_idx]
162-
if (verbose) cat(sprintf(" lambda_graph = %.4f (%d/%d)\n",
163-
lg, lg_idx, length(lambda_graph_seq)))
164180

165-
for (fold_idx in 1:nfolds) {
166-
train_idx <- cv_folds[[fold_idx]]
167-
test_idx <- setdiff(1:n, train_idx)
181+
train_idx <- cv_folds[[fold_idx]]
182+
test_idx <- setdiff(1:n, train_idx)
168183

169-
X_train <- X_scaled[train_idx, , drop = FALSE]
170-
Y_train <- Y_centered[train_idx, , drop = FALSE]
171-
X_test <- X_scaled[test_idx, , drop = FALSE]
172-
Y_test <- Y_centered[test_idx, , drop = FALSE]
184+
X_train <- X_scaled[train_idx, , drop = FALSE]
185+
Y_train <- Y_centered[train_idx, , drop = FALSE]
186+
X_test <- X_scaled[test_idx, , drop = FALSE]
187+
Y_test <- Y_centered[test_idx, , drop = FALSE]
173188

174-
# Warm start path
175-
B_warm <- matrix(0, p, q)
189+
# Warm start path over lambda1 (must be sequential)
190+
B_warm <- matrix(0, p, q)
191+
mse_vec <- numeric(length(lambda1_seq))
176192

177-
for (l1_idx in seq_along(lambda1_seq)) {
178-
l1 <- lambda1_seq[l1_idx]
193+
for (l1_idx in seq_along(lambda1_seq)) {
194+
l1 <- lambda1_seq[l1_idx]
179195

180-
B_warm <- .fit_graph_reg(X_train, Y_train, L, l1, lg, alpha,
181-
B_init = B_warm, max_iter = 200, tol = 1e-4)
196+
B_warm <- .fit_graph_reg(X_train, Y_train, L, l1, lg, alpha,
197+
B_init = B_warm, max_iter = 200, tol = 1e-4)
182198

183-
pred <- X_test %*% B_warm
184-
mse <- mean((Y_test - pred)^2)
185-
cv_results[l1_idx, lg_idx] <- cv_results[l1_idx, lg_idx] + mse / nfolds
186-
}
199+
pred <- X_test %*% B_warm
200+
mse_vec[l1_idx] <- mean((Y_test - pred)^2)
187201
}
202+
203+
list(lg_idx = lg_idx, fold_idx = fold_idx, mse_vec = mse_vec)
204+
}
205+
206+
# Run CV tasks (parallel or sequential)
207+
if (par && !is.null(n.cores) && n.cores > 1) {
208+
if (verbose) cat(sprintf(" Running in parallel with %d cores...\n", n.cores))
209+
cv_task_results <- parallel::mclapply(1:n_tasks, run_cv_task,
210+
mc.cores = min(n.cores, n_tasks))
211+
} else {
212+
if (verbose) cat(" Running sequentially...\n")
213+
cv_task_results <- lapply(1:n_tasks, run_cv_task)
214+
}
215+
216+
# Aggregate results into cv_results matrix
217+
cv_results <- matrix(0, nrow = length(lambda1_seq), ncol = length(lambda_graph_seq))
218+
219+
for (res in cv_task_results) {
220+
cv_results[, res$lg_idx] <- cv_results[, res$lg_idx] + res$mse_vec / nfolds
188221
}
189222

190223
# Find best parameters
@@ -209,8 +242,7 @@ multivariate_graph_reg <- function(X,
209242
}
210243

211244
# Get CV predictions for R² calculation
212-
cv_preds <- matrix(0, nrow = n, ncol = q)
213-
for (fold_idx in 1:nfolds) {
245+
run_final_fold <- function(fold_idx) {
214246
train_idx <- cv_folds[[fold_idx]]
215247
test_idx <- setdiff(1:n, train_idx)
216248

@@ -220,7 +252,19 @@ multivariate_graph_reg <- function(X,
220252

221253
B_fold <- .fit_graph_reg(X_train, Y_train, L, best_lambda1, best_lambda_graph,
222254
alpha, B_init = NULL, max_iter = 500, tol = 1e-4)
223-
cv_preds[test_idx, ] <- X_test %*% B_fold
255+
list(test_idx = test_idx, preds = X_test %*% B_fold)
256+
}
257+
258+
if (par && !is.null(n.cores) && n.cores > 1) {
259+
fold_results <- parallel::mclapply(1:nfolds, run_final_fold,
260+
mc.cores = min(n.cores, nfolds))
261+
} else {
262+
fold_results <- lapply(1:nfolds, run_final_fold)
263+
}
264+
265+
cv_preds <- matrix(0, nrow = n, ncol = q)
266+
for (fr in fold_results) {
267+
cv_preds[fr$test_idx, ] <- fr$preds
224268
}
225269

226270
# Add back means

0 commit comments

Comments
 (0)