Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions R/colocboost_init.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ colocboost_init_model <- function(cb_data,
"learning_rate_init" = learning_rate_init,
"stop_thresh" = stop_thresh,
"ld_jk" = c(),
"jk" = c()
"jk" = c(),
"scaling_factor" = if (!is.null(cb_data$data[[i]]$N)) (cb_data$data[[i]]$N - 1) else 1,
"beta_scaling" = if (!is.null(cb_data$data[[i]]$N)) 1 else 100
)

data_each <- cb_data$data[[i]]
Expand Down Expand Up @@ -375,7 +377,8 @@ inital_residual <- function(Y = NULL, XtY = NULL) {

# - Calculate correlation between X and res
get_correlation <- function(X = NULL, res = NULL, XtY = NULL, N = NULL,
YtY = NULL, XtX = NULL, beta_k = NULL, miss_idx = NULL) {
YtY = NULL, XtX = NULL, beta_k = NULL, miss_idx = NULL,
XtX_beta_cache = NULL) {
if (!is.null(X)) {
corr <- suppressWarnings({
Rfast::correls(res, X)[, "correlation"]
Expand All @@ -399,6 +402,8 @@ get_correlation <- function(X = NULL, res = NULL, XtY = NULL, N = NULL,
}
if (length(XtX) == 1){
var_r <- YtY - 2 * sum(beta_k * XtY) + sum(beta_k^2)
} else if (!is.null(XtX_beta_cache)) {
var_r <- YtY - 2 * sum(beta_k * XtY) + sum(XtX_beta_cache * beta_k)
} else {
var_r <- YtY - 2 * sum(beta_k * XtY) + sum((XtX %*% as.matrix(beta_k)) * beta_k)
}
Expand Down
23 changes: 14 additions & 9 deletions R/colocboost_update.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ colocboost_update <- function(cb_model, cb_model_para, cb_data) {
)

x_tmp <- cb_data$data[[X_dict]]$X
scaling_factor <- if (!is.null(cb_data$data[[i]]$N)) (cb_data$data[[i]]$N - 1) else 1
scaling_factor <- cb_model[[i]]$scaling_factor
cov_Xtr <- if (!is.null(x_tmp)) {
t(x_tmp) %*% as.matrix(cb_model[[i]]$res) / scaling_factor
} else {
Expand Down Expand Up @@ -104,36 +104,41 @@ colocboost_update <- function(cb_model, cb_model_para, cb_data) {
beta <- cb_model[[i]]$beta
profile_log <- mean((y - x %*% beta)^2) * adj_dep
} else if (!is.null(cb_data$data[[X_dict]]$XtX)) {
scaling_factor <- if (!is.null(cb_data$data[[i]]$N)) cb_data$data[[i]]$N - 1 else 1
beta_scaling <- if (!is.null(cb_data$data[[i]]$N)) 1 else 100
beta_scaling <- cb_model[[i]]$beta_scaling
# - summary statistics
xtx <- cb_data$data[[X_dict]]$XtX
cb_model[[i]]$res <- rep(0, cb_model_para$P)
if (length(cb_data$data[[i]]$variable_miss) != 0) {
beta <- cb_model[[i]]$beta[-cb_data$data[[i]]$variable_miss] / beta_scaling
xty <- cb_data$data[[i]]$XtY[-cb_data$data[[i]]$variable_miss]
if (length(xtx) == 1){
XtX_beta <- beta
cb_model[[i]]$res[-cb_data$data[[i]]$variable_miss] <- xty - scaling_factor * beta
} else {
cb_model[[i]]$res[-cb_data$data[[i]]$variable_miss] <- xty - scaling_factor * xtx %*% beta
XtX_beta <- xtx %*% beta
cb_model[[i]]$res[-cb_data$data[[i]]$variable_miss] <- xty - scaling_factor * XtX_beta
}

} else {
beta <- cb_model[[i]]$beta / beta_scaling
xty <- cb_data$data[[i]]$XtY
if (length(xtx) == 1){
XtX_beta <- beta
cb_model[[i]]$res <- xty - scaling_factor * beta
} else {
cb_model[[i]]$res <- xty - scaling_factor * xtx %*% beta
XtX_beta <- xtx %*% beta
cb_model[[i]]$res <- xty - scaling_factor * XtX_beta
}
}
# - profile-loglikelihood
# - cache XtX %*% beta for reuse in get_correlation (avoids redundant O(P^2) computation)
cb_model[[i]]$XtX_beta_cache <- XtX_beta
# - profile-loglikelihood (reuses cached XtX_beta)
yty <- cb_data$data[[i]]$YtY / scaling_factor
xty <- xty / scaling_factor
if (length(xtx) == 1){
profile_log <- (yty - 2 * sum(beta * xty) + sum(beta^2)) * adj_dep
} else {
profile_log <- (yty - 2 * sum(beta * xty) + sum((xtx %*% as.matrix(beta)) * beta)) * adj_dep
profile_log <- (yty - 2 * sum(beta * xty) + sum(XtX_beta * beta)) * adj_dep
}
}
cb_model[[i]]$profile_loglike_each <- c(cb_model[[i]]$profile_loglike_each, profile_log)
Expand Down Expand Up @@ -277,7 +282,7 @@ boost_obj_last <- function(cb_data, cb_model, cb_model_para) {
)

x_tmp <- cb_data$data[[X_dict]]$X
scaling_factor <- if (!is.null(cb_data$data[[i]]$N)) (cb_data$data[[i]]$N - 1) else 1
scaling_factor <- cb_model[[i]]$scaling_factor
cov_Xtr <- if (!is.null(x_tmp)) {
t(x_tmp) %*% as.matrix(cb_model[[i]]$res) / scaling_factor
} else {
Expand Down
3 changes: 2 additions & 1 deletion R/colocboost_workhorse.R
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ cb_model_update <- function(cb_data, cb_model, cb_model_para) {
N = data_each$N, YtY = data_each$YtY,
XtX = cb_data$data[[X_dict]]$XtX,
beta_k = model_each$beta,
miss_idx = data_each$variable_miss
miss_idx = data_each$variable_miss,
XtX_beta_cache = model_each$XtX_beta_cache
)
cb_model[[i]]$correlation <- tmp
cb_model[[i]]$z <- get_z(tmp, n = data_each$N, model_each$res)
Expand Down
145 changes: 145 additions & 0 deletions inst/benchmark/benchmark_phase1.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#!/usr/bin/env Rscript
# Benchmark: Phase 1 optimization (XtX*beta cache + precomputed constants)
#
# This script compares the optimized code against a simulated "no-cache" baseline
# by manually running the dominant O(P^2) operations the number of times they
# would occur with and without caching.
#
# The optimization eliminates 2 of 3 redundant XtX %*% beta computations per
# iteration per outcome.

library(MASS)

cat("=== ColocBoost Phase 1 Optimization Benchmark ===\n\n")

# ---- Generate test data at different scales ----
run_benchmark <- function(p, L, M, n_ref = 500, seed = 42) {
set.seed(seed)

cat(sprintf("P = %d variants, L = %d outcomes, M = %d iterations, N_ref = %d\n", p, L, M, n_ref))

# Generate LD matrix (P x P)
sigma <- matrix(0, p, p)
for (i in 1:p) {
for (j in 1:p) {
sigma[i, j] <- 0.9^abs(i - j)
}
}
# Ensure positive definite
LD <- sigma

# Generate random beta vector
beta <- rnorm(p) * 0.01

cat(sprintf(" LD matrix size: %.1f MB\n", object.size(LD) / 1024^2))

# ---- Benchmark: XtX %*% beta ----
# Before optimization: 3 * M * L calls to XtX %*% beta
# After optimization: 1 * M * L calls to XtX %*% beta

n_calls_before <- 3 * M * L
n_calls_after <- 1 * M * L

# Time a single XtX %*% beta
t_single <- system.time({
for (rep in 1:100) {
result <- LD %*% beta
}
})[["elapsed"]] / 100

cat(sprintf(" Single XtX %%*%% beta: %.4f seconds\n", t_single))
cat(sprintf(" Before optimization: %d calls = %.2f seconds\n",
n_calls_before, n_calls_before * t_single))
cat(sprintf(" After optimization: %d calls = %.2f seconds\n",
n_calls_after, n_calls_after * t_single))
cat(sprintf(" Speedup on dominant cost: %.1fx\n",
n_calls_before * t_single / (n_calls_after * t_single)))
cat(sprintf(" Time saved: %.2f seconds\n\n",
(n_calls_before - n_calls_after) * t_single))

invisible(list(
p = p, L = L, M = M,
t_single = t_single,
t_before = n_calls_before * t_single,
t_after = n_calls_after * t_single
))
}

# ---- Run end-to-end colocboost benchmark ----
run_colocboost_benchmark <- function(p, L, M, n = 200, seed = 42) {
set.seed(seed)

cat(sprintf("\n--- End-to-end colocboost: P=%d, L=%d, M=%d ---\n", p, L, M))

sigma <- matrix(0, p, p)
for (i in 1:p) {
for (j in 1:p) {
sigma[i, j] <- 0.9^abs(i - j)
}
}
X <- mvrnorm(n, rep(0, p), sigma)
colnames(X) <- paste0("SNP", 1:p)
Y <- matrix(rnorm(n * L), n, L)
# Add signal to SNP5 for all traits
for (l in 1:L) {
Y[, l] <- Y[, l] + X[, 5] * 0.5
}
LD <- cor(X)

# Generate summary statistics
sumstat_list <- list()
for (i in 1:L) {
z <- rep(0, p)
beta_hat <- rep(0, p)
se_hat <- rep(0, p)
for (j in 1:p) {
fit <- summary(lm(Y[, i] ~ X[, j]))$coef
if (nrow(fit) == 2) {
beta_hat[j] <- fit[2, 1]
se_hat[j] <- fit[2, 2]
z[j] <- beta_hat[j] / se_hat[j]
}
}
sumstat_list[[i]] <- data.frame(
beta = beta_hat, sebeta = se_hat, z = z,
n = n, variant = colnames(X)
)
}

# Time full colocboost run
t_full <- system.time({
suppressWarnings(suppressMessages({
result <- colocboost::colocboost(
sumstat = sumstat_list,
LD = LD,
M = M,
output_level = 1
)
}))
})[["elapsed"]]

cat(sprintf(" Total wall time: %.2f seconds\n", t_full))
invisible(t_full)
}

# ---- Scenarios ----

cat("--- Micro-benchmark: XtX %*% beta operation ---\n\n")

results <- list()
results[[1]] <- run_benchmark(p = 1000, L = 2, M = 100)
results[[2]] <- run_benchmark(p = 2000, L = 5, M = 200)
results[[3]] <- run_benchmark(p = 5000, L = 3, M = 300)
results[[4]] <- run_benchmark(p = 5000, L = 10, M = 500)

cat("\n--- Summary Table ---\n")
cat(sprintf("%-8s %-4s %-5s %-12s %-12s %-8s\n",
"P", "L", "M", "Before(s)", "After(s)", "Speedup"))
for (r in results) {
cat(sprintf("%-8d %-4d %-5d %-12.2f %-12.2f %-8.1fx\n",
r$p, r$L, r$M, r$t_before, r$t_after, r$t_before / r$t_after))
}

cat("\n--- End-to-end colocboost timings ---\n")
run_colocboost_benchmark(p = 100, L = 2, M = 50)
run_colocboost_benchmark(p = 100, L = 5, M = 100)
Loading