Skip to content

Commit d71bfcf

Browse files
authored
Merge pull request #132 from StatFunGen/optimize/cache-xtx-beta-phase1
Optimize summary stats mode: cache XtX*beta to eliminate 3x redundant…
2 parents cc6d147 + dac0997 commit d71bfcf

File tree

5 files changed

+586
-12
lines changed

5 files changed

+586
-12
lines changed

R/colocboost_init.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,9 @@ colocboost_init_model <- function(cb_data,
174174
"learning_rate_init" = learning_rate_init,
175175
"stop_thresh" = stop_thresh,
176176
"ld_jk" = c(),
177-
"jk" = c()
177+
"jk" = c(),
178+
"scaling_factor" = if (!is.null(cb_data$data[[i]]$N)) (cb_data$data[[i]]$N - 1) else 1,
179+
"beta_scaling" = if (!is.null(cb_data$data[[i]]$N)) 1 else 100
178180
)
179181

180182
data_each <- cb_data$data[[i]]
@@ -375,7 +377,8 @@ inital_residual <- function(Y = NULL, XtY = NULL) {
375377

376378
# - Calculate correlation between X and res
377379
get_correlation <- function(X = NULL, res = NULL, XtY = NULL, N = NULL,
378-
YtY = NULL, XtX = NULL, beta_k = NULL, miss_idx = NULL) {
380+
YtY = NULL, XtX = NULL, beta_k = NULL, miss_idx = NULL,
381+
XtX_beta_cache = NULL) {
379382
if (!is.null(X)) {
380383
corr <- suppressWarnings({
381384
Rfast::correls(res, X)[, "correlation"]
@@ -399,6 +402,8 @@ get_correlation <- function(X = NULL, res = NULL, XtY = NULL, N = NULL,
399402
}
400403
if (length(XtX) == 1){
401404
var_r <- YtY - 2 * sum(beta_k * XtY) + sum(beta_k^2)
405+
} else if (!is.null(XtX_beta_cache)) {
406+
var_r <- YtY - 2 * sum(beta_k * XtY) + sum(XtX_beta_cache * beta_k)
402407
} else {
403408
var_r <- YtY - 2 * sum(beta_k * XtY) + sum((XtX %*% as.matrix(beta_k)) * beta_k)
404409
}

R/colocboost_update.R

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ colocboost_update <- function(cb_model, cb_model_para, cb_data) {
5454
)
5555

5656
x_tmp <- cb_data$data[[X_dict]]$X
57-
scaling_factor <- if (!is.null(cb_data$data[[i]]$N)) (cb_data$data[[i]]$N - 1) else 1
57+
scaling_factor <- cb_model[[i]]$scaling_factor
5858
cov_Xtr <- if (!is.null(x_tmp)) {
5959
t(x_tmp) %*% as.matrix(cb_model[[i]]$res) / scaling_factor
6060
} else {
@@ -104,36 +104,41 @@ colocboost_update <- function(cb_model, cb_model_para, cb_data) {
104104
beta <- cb_model[[i]]$beta
105105
profile_log <- mean((y - x %*% beta)^2) * adj_dep
106106
} else if (!is.null(cb_data$data[[X_dict]]$XtX)) {
107-
scaling_factor <- if (!is.null(cb_data$data[[i]]$N)) cb_data$data[[i]]$N - 1 else 1
108-
beta_scaling <- if (!is.null(cb_data$data[[i]]$N)) 1 else 100
107+
beta_scaling <- cb_model[[i]]$beta_scaling
109108
# - summary statistics
110109
xtx <- cb_data$data[[X_dict]]$XtX
111110
cb_model[[i]]$res <- rep(0, cb_model_para$P)
112111
if (length(cb_data$data[[i]]$variable_miss) != 0) {
113112
beta <- cb_model[[i]]$beta[-cb_data$data[[i]]$variable_miss] / beta_scaling
114113
xty <- cb_data$data[[i]]$XtY[-cb_data$data[[i]]$variable_miss]
115114
if (length(xtx) == 1){
115+
XtX_beta <- beta
116116
cb_model[[i]]$res[-cb_data$data[[i]]$variable_miss] <- xty - scaling_factor * beta
117117
} else {
118-
cb_model[[i]]$res[-cb_data$data[[i]]$variable_miss] <- xty - scaling_factor * xtx %*% beta
118+
XtX_beta <- xtx %*% beta
119+
cb_model[[i]]$res[-cb_data$data[[i]]$variable_miss] <- xty - scaling_factor * XtX_beta
119120
}
120-
121+
121122
} else {
122123
beta <- cb_model[[i]]$beta / beta_scaling
123124
xty <- cb_data$data[[i]]$XtY
124125
if (length(xtx) == 1){
126+
XtX_beta <- beta
125127
cb_model[[i]]$res <- xty - scaling_factor * beta
126128
} else {
127-
cb_model[[i]]$res <- xty - scaling_factor * xtx %*% beta
129+
XtX_beta <- xtx %*% beta
130+
cb_model[[i]]$res <- xty - scaling_factor * XtX_beta
128131
}
129132
}
130-
# - profile-loglikelihood
133+
# - cache XtX %*% beta for reuse in get_correlation (avoids redundant O(P^2) computation)
134+
cb_model[[i]]$XtX_beta_cache <- XtX_beta
135+
# - profile-loglikelihood (reuses cached XtX_beta)
131136
yty <- cb_data$data[[i]]$YtY / scaling_factor
132137
xty <- xty / scaling_factor
133138
if (length(xtx) == 1){
134139
profile_log <- (yty - 2 * sum(beta * xty) + sum(beta^2)) * adj_dep
135140
} else {
136-
profile_log <- (yty - 2 * sum(beta * xty) + sum((xtx %*% as.matrix(beta)) * beta)) * adj_dep
141+
profile_log <- (yty - 2 * sum(beta * xty) + sum(XtX_beta * beta)) * adj_dep
137142
}
138143
}
139144
cb_model[[i]]$profile_loglike_each <- c(cb_model[[i]]$profile_loglike_each, profile_log)
@@ -277,7 +282,7 @@ boost_obj_last <- function(cb_data, cb_model, cb_model_para) {
277282
)
278283

279284
x_tmp <- cb_data$data[[X_dict]]$X
280-
scaling_factor <- if (!is.null(cb_data$data[[i]]$N)) (cb_data$data[[i]]$N - 1) else 1
285+
scaling_factor <- cb_model[[i]]$scaling_factor
281286
cov_Xtr <- if (!is.null(x_tmp)) {
282287
t(x_tmp) %*% as.matrix(cb_model[[i]]$res) / scaling_factor
283288
} else {

R/colocboost_workhorse.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,8 @@ cb_model_update <- function(cb_data, cb_model, cb_model_para) {
283283
N = data_each$N, YtY = data_each$YtY,
284284
XtX = cb_data$data[[X_dict]]$XtX,
285285
beta_k = model_each$beta,
286-
miss_idx = data_each$variable_miss
286+
miss_idx = data_each$variable_miss,
287+
XtX_beta_cache = model_each$XtX_beta_cache
287288
)
288289
cb_model[[i]]$correlation <- tmp
289290
cb_model[[i]]$z <- get_z(tmp, n = data_each$N, model_each$res)

inst/benchmark/benchmark_phase1.R

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#!/usr/bin/env Rscript
2+
# Benchmark: Phase 1 optimization (XtX*beta cache + precomputed constants)
3+
#
4+
# This script compares the optimized code against a simulated "no-cache" baseline
5+
# by manually running the dominant O(P^2) operations the number of times they
6+
# would occur with and without caching.
7+
#
8+
# Of the 3 XtX %*% beta products computed per iteration per outcome, 2 are
# redundant; the optimization caches the first result and eliminates those 2.
10+
11+
library(MASS)
12+
13+
cat("=== ColocBoost Phase 1 Optimization Benchmark ===\n\n")
14+
15+
# ---- Generate test data at different scales ----
16+
#' Micro-benchmark the dominant `XtX %*% beta` product
#'
#' Builds a dense P x P matrix with entries `0.9^|i - j|`, times a single
#' matrix-vector product against a random vector (averaged over 100
#' repetitions), and extrapolates that cost to the number of products
#' performed without caching (3 per iteration per outcome) and with caching
#' (1 per iteration per outcome). A report is printed with `cat()`.
#'
#' @param p Number of variants (dimension of the LD matrix).
#' @param L Number of outcomes.
#' @param M Number of iterations.
#' @param n_ref Reference sample size; only echoed in the printed header.
#' @param seed Seed passed to `set.seed()` before drawing the random vector.
#' @return Invisibly, a list with elements `p`, `L`, `M`, `t_single`
#'   (seconds per product), `t_before` and `t_after` (extrapolated seconds).
run_benchmark <- function(p, L, M, n_ref = 500, seed = 42) {
  set.seed(seed)

  cat(sprintf(
    "P = %d variants, L = %d outcomes, M = %d iterations, N_ref = %d\n",
    p, L, M, n_ref
  ))

  # Generate LD matrix (P x P) in one vectorized step instead of a nested
  # interpreted loop; the entries are the same 0.9^|i - j| values.
  idx <- seq_len(p)
  LD <- 0.9^abs(outer(idx, idx, "-"))

  # Generate random beta vector
  beta <- rnorm(p) * 0.01

  cat(sprintf(
    " LD matrix size: %.1f MB\n",
    as.numeric(object.size(LD)) / 1024^2
  ))

  # ---- Benchmark: XtX %*% beta ----
  # Before optimization: 3 * M * L calls to XtX %*% beta
  # After optimization:  1 * M * L calls to XtX %*% beta
  n_calls_before <- 3 * M * L
  n_calls_after <- 1 * M * L

  # Time a single XtX %*% beta, averaged over several repetitions
  n_reps <- 100
  t_single <- system.time({
    for (k in seq_len(n_reps)) {
      LD %*% beta
    }
  })[["elapsed"]] / n_reps

  t_before <- n_calls_before * t_single
  t_after <- n_calls_after * t_single

  cat(sprintf(" Single XtX %%*%% beta: %.4f seconds\n", t_single))
  cat(sprintf(
    " Before optimization: %d calls = %.2f seconds\n",
    n_calls_before, t_before
  ))
  cat(sprintf(
    " After optimization: %d calls = %.2f seconds\n",
    n_calls_after, t_after
  ))
  # t_single cancels out of the ratio; using the call counts directly keeps
  # the reported speedup finite when the timer resolution rounds t_single to 0.
  cat(sprintf(
    " Speedup on dominant cost: %.1fx\n",
    n_calls_before / n_calls_after
  ))
  cat(sprintf(
    " Time saved: %.2f seconds\n\n",
    (n_calls_before - n_calls_after) * t_single
  ))

  invisible(list(
    p = p, L = L, M = M,
    t_single = t_single,
    t_before = t_before,
    t_after = t_after
  ))
}
67+
68+
# ---- Run end-to-end colocboost benchmark ----
69+
#' Time an end-to-end `colocboost::colocboost()` run on simulated data
#'
#' Simulates genotypes from a multivariate normal whose covariance has entries
#' `0.9^|i - j|`, adds a shared effect of column 5 (effect size 0.5) to every
#' outcome, derives per-variant marginal summary statistics with `lm()`, and
#' measures the elapsed time of one summary-statistics `colocboost()` call.
#'
#' @param p Number of variants; must be at least 5 because column 5 carries
#'   the simulated signal.
#' @param L Number of outcomes.
#' @param M Value passed to the `M` argument of `colocboost()`.
#' @param n Number of simulated samples.
#' @param seed Seed passed to `set.seed()` before simulation.
#' @return Invisibly, the elapsed wall time in seconds.
run_colocboost_benchmark <- function(p, L, M, n = 200, seed = 42) {
  # Column 5 is hard-coded as the causal variant below; fail early with a
  # clear message instead of a subscript error deep inside the simulation.
  stopifnot(p >= 5)
  set.seed(seed)

  cat(sprintf("\n--- End-to-end colocboost: P=%d, L=%d, M=%d ---\n", p, L, M))

  # Covariance with entries 0.9^|i - j|, built in one vectorized step
  idx <- seq_len(p)
  sigma <- 0.9^abs(outer(idx, idx, "-"))

  X <- MASS::mvrnorm(n, rep(0, p), sigma)
  colnames(X) <- paste0("SNP", idx)
  Y <- matrix(rnorm(n * L), n, L)
  # Add signal to SNP5 for all traits; the length-n vector is recycled down
  # each column of the n x L matrix.
  Y <- Y + X[, 5] * 0.5
  LD <- cor(X)

  # Generate summary statistics: one marginal regression per variant/outcome
  sumstat_list <- lapply(seq_len(L), function(i) {
    beta_hat <- numeric(p)
    se_hat <- numeric(p)
    z <- numeric(p)
    for (j in idx) {
      fit <- coef(summary(lm(Y[, i] ~ X[, j])))
      # Only fill in the statistics when both the intercept and slope rows
      # are present in the coefficient table; otherwise the zeros are kept.
      if (nrow(fit) == 2) {
        beta_hat[j] <- fit[2, 1]
        se_hat[j] <- fit[2, 2]
        z[j] <- beta_hat[j] / se_hat[j]
      }
    }
    data.frame(
      beta = beta_hat, sebeta = se_hat, z = z,
      n = n, variant = colnames(X)
    )
  })

  # Time full colocboost run
  t_full <- system.time({
    suppressWarnings(suppressMessages({
      result <- colocboost::colocboost(
        sumstat = sumstat_list,
        LD = LD,
        M = M,
        output_level = 1
      )
    }))
  })[["elapsed"]]

  cat(sprintf(" Total wall time: %.2f seconds\n", t_full))
  invisible(t_full)
}
124+
125+
# ---- Scenarios ----
126+
127+
cat("--- Micro-benchmark: XtX %*% beta operation ---\n\n")

# Micro-benchmark configurations, run in the order listed.
micro_scenarios <- list(
  list(p = 1000, L = 2, M = 100),
  list(p = 2000, L = 5, M = 200),
  list(p = 5000, L = 3, M = 300),
  list(p = 5000, L = 10, M = 500)
)
results <- lapply(micro_scenarios, function(cfg) {
  run_benchmark(p = cfg$p, L = cfg$L, M = cfg$M)
})

# One formatted row per scenario: extrapolated cost before/after and the ratio.
cat("\n--- Summary Table ---\n")
cat(sprintf("%-8s %-4s %-5s %-12s %-12s %-8s\n",
            "P", "L", "M", "Before(s)", "After(s)", "Speedup"))
table_rows <- vapply(results, function(r) {
  sprintf("%-8d %-4d %-5d %-12.2f %-12.2f %-8.1fx\n",
          r$p, r$L, r$M, r$t_before, r$t_after, r$t_before / r$t_after)
}, character(1))
cat(table_rows, sep = "")

# Full colocboost runs on small simulated data sets.
cat("\n--- End-to-end colocboost timings ---\n")
end_to_end_scenarios <- list(
  list(p = 100, L = 2, M = 50),
  list(p = 100, L = 5, M = 100)
)
for (cfg in end_to_end_scenarios) {
  run_colocboost_benchmark(p = cfg$p, L = cfg$L, M = cfg$M)
}

0 commit comments

Comments
 (0)