Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions tests/testthat/test-contrast_rsa_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -750,3 +750,89 @@ test_that("contrast_rsa_model output metrics are internally consistent", {
beta_delta_rel <- run_metric("beta_delta_reliable", reliability = TRUE)
expect_equal(beta_delta_rel, beta_delta_vec * rho_const, tolerance = 1e-6)
})

test_that("beta_delta_reliable uses reliability weights from fold deltas", {
dset <- mock_mvpa_dataset_train(n_samples = 8, n_cond = 4, n_blocks = 2, n_voxels = 1)
colnames(dset$train_data) <- "V1"
mvpa_des <- dset$design

C_custom <- matrix(c(
1, 0,
0, 1,
0, 0,
0, 0
), nrow = 4, ncol = 2, byrow = TRUE,
dimnames = list(levels(mvpa_des$Y), c("Con1", "Con2")))

ms_des <- msreve_design(mvpa_des, C_custom)

model_spec <- contrast_rsa_model(
dataset = dset,
design = ms_des,
regression_type = "pearson",
output_metric = c("beta_delta_reliable"),
calc_reliability = TRUE,
check_collinearity = FALSE
)

cv_spec <- mock_cv_spec_s3(mvpa_des)
sl_data <- dset$train_data
sl_info <- list(center_local_id = 1, center_global_id = 1, radius = 0, n_voxels = 1)

fold_estimates_mock <- array(0, dim = c(4, 1, 2),
dimnames = list(levels(mvpa_des$Y), "V1", c("Fold1", "Fold2")))
fold_estimates_mock[, 1, 1] <- c(1, 0, 0, 0)
fold_estimates_mock[, 1, 2] <- c(2, 0, 0, 0)
mean_estimate_mock <- apply(fold_estimates_mock, c(1, 2), mean)

Delta1 <- t(fold_estimates_mock[, , 1]) %*% C_custom
Delta2 <- t(fold_estimates_mock[, , 2]) %*% C_custom
deltas <- rbind(Delta1[1, ], Delta2[1, ])
mean_delta <- c(0, 0)
M2_delta <- c(0, 0)
valid_folds <- 0
for (i in seq_len(nrow(deltas))) {
delta_fold <- deltas[i, ]
if (anyNA(delta_fold)) next
valid_folds <- valid_folds + 1
delta_diff <- delta_fold - mean_delta
mean_delta <- mean_delta + delta_diff / valid_folds
M2_delta <- M2_delta + delta_diff * (delta_fold - mean_delta)
}
rho_expected <- rep(1, ncol(C_custom))
if (valid_folds > 1) {
var_delta <- M2_delta / (valid_folds - 1)
sigma2_noise_param <- (valid_folds - 1) * var_delta
denom <- var_delta + sigma2_noise_param
rho_expected <- ifelse(denom < 1e-10, 1, sigma2_noise_param / denom)
rho_expected[is.na(rho_expected)] <- 0
} else if (valid_folds == 1) {
rho_expected[M2_delta == 0] <- 1
rho_expected[M2_delta != 0] <- 0
} else {
rho_expected <- rep(0, ncol(C_custom))
}

beta_mock <- c(2, 1)

result <- with_mocked_bindings(
compute_crossvalidated_means_sl = function(...) {
list(mean_estimate = mean_estimate_mock, fold_estimates = fold_estimates_mock)
},
run_cor = function(dvec, obj) {
setNames(beta_mock, colnames(obj$design$model_mat))
},
.package = "rMVPA",
{
train_model.contrast_rsa_model(model_spec, sl_data, sl_info, cv_spec)
}
)

Delta_sl <- t(mean_estimate_mock) %*% C_custom
delta_vc_sl <- Delta_sl[1, ]
beta_delta_expected <- beta_mock * delta_vc_sl

expect_equal(result$beta_delta_reliable,
beta_delta_expected * rho_expected,
tolerance = 1e-6)
})