diff --git a/NAMESPACE b/NAMESPACE index a87181a7..7b96818a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -16,6 +16,7 @@ S3method(print,summary.bgms) S3method(summary,bgms) export(bgm) export(bgmCompare) +export(bgmCompare2) export(extract_arguments) export(extract_category_thresholds.bgms) export(extract_edge_indicators) diff --git a/R/RcppExports.R b/R/RcppExports.R index d5df61c8..2dce0067 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,8 +1,8 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -run_bgmCompare_parallel <- function(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads) { - .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads) +run_bgmCompare_parallel <- function(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed) { + .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed) } run_bgm_parallel <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) { diff --git a/R/bgmCompare2.R b/R/bgmCompare2.R index 293c6773..7711e79b 100644 --- a/R/bgmCompare2.R +++ b/R/bgmCompare2.R @@ -196,7 +196,7 @@ bgmCompare2 = function( stop("Argument 'seed' must be a non-negative integer or vector of non-negative integers.") } # Force to integer type - seed <- as.integer(seed) + seed = as.integer(seed) dqrng::dqset.seed(seed) } @@ -225,44 +225,14 @@ bgmCompare2 = function( nuts_max_depth = nuts_max_depth, learn_mass_matrix = learn_mass_matrix, projection = projection, - group_membership = sorted_group - 1, ###################################### + group_membership = sorted_group - 1, group_indices = group_indices, interaction_index_matrix = Index, inclusion_probability = model$inclusion_probability_difference, - num_chains = chains, nThreads = cores + num_chains = chains, nThreads = cores, + seed = seed ) - # out = run_gibbs_sampler_for_bgmCompare( - # chain_id = 1, - # observations = observations, - # num_groups = num_groups, - # num_obs_categories = num_obs_categories, - # sufficient_blume_capel = sufficient_blume_capel, - # sufficient_pairwise = sufficient_pairwise, - # num_categories = num_categories[, 1], - # main_alpha = threshold_alpha, - # main_beta = threshold_beta, - # pairwise_scale = interaction_scale, - # difference_scale = difference_scale, - # difference_selection_alpha = beta_bernoulli_alpha, - # difference_selection_beta = beta_bernoulli_beta, - # difference_prior = model$difference_prior, iter = iter, burnin = burnin, - # na_impute = na_impute, missing_data_indices = missing_index, - # is_ordinal_variable = ordinal_variable, - # baseline_category = reference_category, - # difference_selection = difference_selection, - # main_effect_indices = main_effect_indices, - # pairwise_effect_indices = pairwise_effect_indices, - # target_accept = target_accept, - # nuts_max_depth = nuts_max_depth, - # learn_mass_matrix = learn_mass_matrix, - # projection = projection, - # group_membership = sorted_group - 1, ###################################### - # group_indices = group_indices, - # interaction_index_matrix = Index, - # inclusion_probability = model$inclusion_probability_difference) - - # Main output handler in the wrapper function # output = prepare_output_bgmCompare2( # out = out, ... diff --git a/R/function_input_utils.R b/R/function_input_utils.R index 73f216c5..236a3175 100644 --- a/R/function_input_utils.R +++ b/R/function_input_utils.R @@ -737,8 +737,10 @@ check_compare2_model = function( if(difference_inclusion_probability >= 1) stop("The inclusion probability for differences cannot equal or exceed the value one.") - inclusion_probability_difference = inclusion_probability_difference + - difference_probability + inclusion_probability_difference = matrix(difference_probability, + nrow = ncol(x), + ncol = ncol(x)) + } else { if(!inherits(difference_probability, what = "matrix") && !inherits(difference_probability, what = "data.frame")) @@ -765,8 +767,9 @@ check_compare2_model = function( stop("One or more inclusion probabilities for differences are one or larger.") } } else { - inclusion_probability_difference = inclusion_probability_difference + 0.5 - + inclusion_probability_difference = matrix(0.5, + nrow = ncol(x), + ncol = ncol(x)) if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0) stop("The scale parameters of the beta distribution for the differences need to be positive.") if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta)) diff --git a/man/bgm.Rd b/man/bgm.Rd index 7c3d4039..74aaa0ef 100644 --- a/man/bgm.Rd +++ b/man/bgm.Rd @@ -29,7 +29,8 @@ bgm( nuts_max_depth = 10, learn_mass_matrix = FALSE, chains = 4, - cores = parallel::detectCores() + cores = parallel::detectCores(), + seed = NULL ) } \arguments{ diff --git a/man/summarySBM.Rd b/man/summarySBM.Rd deleted file mode 100644 index 01109dda..00000000 --- a/man/summarySBM.Rd +++ /dev/null @@ -1,38 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/posterior_utils.R -\name{summarySBM} -\alias{summarySBM} -\title{Function for summarizing the sampled cluster allocation vectors} -\usage{ -summarySBM(bgm_object, internal_call = FALSE) -} -\arguments{ -\item{bgm_object}{A fit object created by the bgm function.} - -\item{internal_call}{A logical value indicating whether the function is used -within bgms for calculating the posterior probabilities of the number of -clusters or by the user. This argument is always set to FALSE.} -} -\value{ -Returns a list of two elements: \code{components} and \code{allocations}, -containing the posterior probabilities for the number of components (clusters) -and the estimated cluster allocation of the nodes using Dahl's method. -} -\description{ -Th \code{summarySBM} function summarizes the sampled allocation vectors from -each iteration of the Gibbs sampler from the output of the \code{bgm} -function ran with \code{edge_prior = "Stochastic-Block"} and -\code{save = TRUE}. It also estimates the posterior distribution of the -number of clusters. -} -\examples{ -\donttest{ - # fit a model with the SBM prior - bgm_object = bgm( - Wenchuan[, c(1:5)], - edge_prior = "Stochastic-Block", - save = TRUE) - - summarySBM(bgm_object) -} -} diff --git a/src/Makevars b/src/Makevars index 585bbd59..0272e93e 100644 --- a/src/Makevars +++ b/src/Makevars @@ -1,4 +1,4 @@ -CXX_STD = CXX14 +CXX_STD = CXX20 ## Pull in the include-paths for RcppParallel PKG_CPPFLAGS = \ @@ -11,4 +11,4 @@ PKG_CPPFLAGS = \ PKG_LIBS = \ $(shell "${R_HOME}/bin/Rscript" -e "cat(RcppParallel::LdFlags())") \ -Wl,-rpath,$(shell "${R_HOME}/bin/Rscript" -e "cat(system.file('lib',package='RcppParallel'))") \ - $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS) + $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS) \ No newline at end of file diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 6d12065c..832f0b7a 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -12,44 +12,45 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // run_bgmCompare_parallel -Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, const int num_groups, const std::vector& num_obs_categories, const std::vector& sufficient_blume_capel, const std::vector& sufficient_pairwise, const arma::ivec& num_categories, const double main_alpha, const double main_beta, const double pairwise_scale, const double difference_scale, const double difference_selection_alpha, const double difference_selection_beta, const std::string& difference_prior, const int iter, const int burnin, const bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, const bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, const double target_accept, const int nuts_max_depth, const bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, const int num_chains, const int nThreads); -RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP sufficient_pairwiseSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP) { +Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector& num_obs_categories, const std::vector& sufficient_blume_capel, const std::vector& sufficient_pairwise, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed); +RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP sufficient_pairwiseSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const arma::imat& >::type observations(observationsSEXP); - Rcpp::traits::input_parameter< const int >::type num_groups(num_groupsSEXP); + Rcpp::traits::input_parameter< int >::type num_groups(num_groupsSEXP); Rcpp::traits::input_parameter< const std::vector& >::type num_obs_categories(num_obs_categoriesSEXP); Rcpp::traits::input_parameter< const std::vector& >::type sufficient_blume_capel(sufficient_blume_capelSEXP); Rcpp::traits::input_parameter< const std::vector& >::type sufficient_pairwise(sufficient_pairwiseSEXP); Rcpp::traits::input_parameter< const arma::ivec& >::type num_categories(num_categoriesSEXP); - Rcpp::traits::input_parameter< const double >::type main_alpha(main_alphaSEXP); - Rcpp::traits::input_parameter< const double >::type main_beta(main_betaSEXP); - Rcpp::traits::input_parameter< const double >::type pairwise_scale(pairwise_scaleSEXP); - Rcpp::traits::input_parameter< const double >::type difference_scale(difference_scaleSEXP); - Rcpp::traits::input_parameter< const double >::type difference_selection_alpha(difference_selection_alphaSEXP); - Rcpp::traits::input_parameter< const double >::type difference_selection_beta(difference_selection_betaSEXP); + Rcpp::traits::input_parameter< double >::type main_alpha(main_alphaSEXP); + Rcpp::traits::input_parameter< double >::type main_beta(main_betaSEXP); + Rcpp::traits::input_parameter< double >::type pairwise_scale(pairwise_scaleSEXP); + Rcpp::traits::input_parameter< double >::type difference_scale(difference_scaleSEXP); + Rcpp::traits::input_parameter< double >::type difference_selection_alpha(difference_selection_alphaSEXP); + Rcpp::traits::input_parameter< double >::type difference_selection_beta(difference_selection_betaSEXP); Rcpp::traits::input_parameter< const std::string& >::type difference_prior(difference_priorSEXP); - Rcpp::traits::input_parameter< const int >::type iter(iterSEXP); - Rcpp::traits::input_parameter< const int >::type burnin(burninSEXP); - Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); + Rcpp::traits::input_parameter< int >::type iter(iterSEXP); + Rcpp::traits::input_parameter< int >::type burnin(burninSEXP); + Rcpp::traits::input_parameter< bool >::type na_impute(na_imputeSEXP); Rcpp::traits::input_parameter< const arma::imat& >::type missing_data_indices(missing_data_indicesSEXP); Rcpp::traits::input_parameter< const arma::uvec& >::type is_ordinal_variable(is_ordinal_variableSEXP); Rcpp::traits::input_parameter< const arma::ivec& >::type baseline_category(baseline_categorySEXP); - Rcpp::traits::input_parameter< const bool >::type difference_selection(difference_selectionSEXP); + Rcpp::traits::input_parameter< bool >::type difference_selection(difference_selectionSEXP); Rcpp::traits::input_parameter< const arma::imat& >::type main_effect_indices(main_effect_indicesSEXP); Rcpp::traits::input_parameter< const arma::imat& >::type pairwise_effect_indices(pairwise_effect_indicesSEXP); - Rcpp::traits::input_parameter< const double >::type target_accept(target_acceptSEXP); - Rcpp::traits::input_parameter< const int >::type nuts_max_depth(nuts_max_depthSEXP); - Rcpp::traits::input_parameter< const bool >::type learn_mass_matrix(learn_mass_matrixSEXP); + Rcpp::traits::input_parameter< double >::type target_accept(target_acceptSEXP); + Rcpp::traits::input_parameter< int >::type nuts_max_depth(nuts_max_depthSEXP); + Rcpp::traits::input_parameter< bool >::type learn_mass_matrix(learn_mass_matrixSEXP); Rcpp::traits::input_parameter< const arma::mat& >::type projection(projectionSEXP); Rcpp::traits::input_parameter< const arma::ivec& >::type group_membership(group_membershipSEXP); Rcpp::traits::input_parameter< const arma::imat& >::type group_indices(group_indicesSEXP); Rcpp::traits::input_parameter< const arma::imat& >::type interaction_index_matrix(interaction_index_matrixSEXP); Rcpp::traits::input_parameter< const arma::mat& >::type inclusion_probability(inclusion_probabilitySEXP); - Rcpp::traits::input_parameter< const int >::type num_chains(num_chainsSEXP); - Rcpp::traits::input_parameter< const int >::type nThreads(nThreadsSEXP); - rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads)); + Rcpp::traits::input_parameter< int >::type num_chains(num_chainsSEXP); + Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP); + Rcpp::traits::input_parameter< int >::type seed(seedSEXP); + rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed)); return rcpp_result_gen; END_RCPP } @@ -188,7 +189,7 @@ END_RCPP } static const R_CallMethodDef CallEntries[] = { - {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 32}, + {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 33}, {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 31}, {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8}, diff --git a/src/bgmCompare_helper.h b/src/bgmCompare_helper.h index 0593f36d..65f5caa6 100644 --- a/src/bgmCompare_helper.h +++ b/src/bgmCompare_helper.h @@ -1,6 +1,7 @@ #pragma once #include +#include "rng_utils.h" @@ -65,4 +66,40 @@ arma::vec inv_mass_active( const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, const bool& selection -); \ No newline at end of file +); + +inline void initialise_graph( + arma::imat& indicator, + arma::mat& main, + arma::mat& pairwise, + const arma::imat& main_indices, + const arma::imat& pairwise_indices, + const arma::mat& incl_prob, + SafeRNG& rng +) { + int V = indicator.n_rows; + int G = main.n_cols; + for (int i = 0; i < V-1; ++i) { + for (int j = i+1; j < V; ++j) { + double p = incl_prob(i,j); + int draw = (runif(rng) < p) ? 1 : 0; + indicator(i,j) = indicator(j,i) = draw; + if (!draw) { + int row = pairwise_indices(i, j); + pairwise.row(row).cols(1, G-1) = 0.0; + } + } + } + for(int i = 0; i < V; i++) { + double p = incl_prob(i,i); + int draw = (runif(rng) < p) ? 1 : 0; + indicator(i,i) = draw; + if(!draw) { + int start = main_indices(i,0); + int end = main_indices(i,1); + for(int row = start; row < end; row++) { + main.row(row).cols(1, G-1) = 0.0; + } + } + } +}; \ No newline at end of file diff --git a/src/bgmCompare_logp_and_grad.cpp b/src/bgmCompare_logp_and_grad.cpp index 1f577cfa..c63af2ed 100644 --- a/src/bgmCompare_logp_and_grad.cpp +++ b/src/bgmCompare_logp_and_grad.cpp @@ -16,9 +16,9 @@ double log_pseudoposterior( const arma::imat& observations, const arma::imat& group_indices, const arma::ivec& num_categories, - const Rcpp::List& num_obs_categories_group, - const Rcpp::List& sufficient_blume_capel_group, - const Rcpp::List& sufficient_pairwise_group, + const std::vector& num_obs_categories_group, + const std::vector& sufficient_blume_capel_group, + const std::vector& sufficient_pairwise_group, const int num_groups, const arma::imat& inclusion_indicator, const arma::uvec& is_ordinal_variable, @@ -34,8 +34,8 @@ double log_pseudoposterior( // --- per group --- for (int group = 0; group < num_groups; ++group) { - const arma::imat num_obs_categories = Rcpp::as(num_obs_categories_group[group]); - const arma::imat sufficient_blume_capel = Rcpp::as(sufficient_blume_capel_group[group]); + const arma::imat num_obs_categories = num_obs_categories_group[group]; + const arma::imat sufficient_blume_capel = sufficient_blume_capel_group[group]; arma::mat main_group(num_variables, max_num_categories, arma::fill::zeros); arma::mat pairwise_group(num_variables, num_variables, arma::fill::zeros); @@ -81,7 +81,7 @@ double log_pseudoposterior( const int r0 = group_indices(group, 0); const int r1 = group_indices(group, 1); const arma::mat obs = arma::conv_to::from(observations.rows(r0, r1)); - const arma::mat sufficient_pairwise = Rcpp::as(sufficient_pairwise_group[group]); + const arma::mat sufficient_pairwise = sufficient_pairwise_group[group]; log_pp += arma::accu(pairwise_group % sufficient_pairwise); // trace(X' * W * X) = sum(W %*% (X'X)) @@ -170,9 +170,9 @@ arma::vec gradient( const arma::imat& observations, const arma::imat& group_indices, const arma::ivec& num_categories, - const Rcpp::List& num_obs_categories_group, - const Rcpp::List& sufficient_blume_capel_group, - const Rcpp::List& sufficient_pairwise_group, + const std::vector& num_obs_categories_group, + const std::vector& sufficient_blume_capel_group, + const std::vector& sufficient_pairwise_group, const int num_groups, const arma::imat& inclusion_indicator, const arma::uvec& is_ordinal_variable, @@ -214,10 +214,8 @@ arma::vec gradient( // ------------------------------- for (int g = 0; g < num_groups; ++g) { // list access - SEXP s1 = num_obs_categories_group[g]; - SEXP s2 = sufficient_blume_capel_group[g]; - arma::imat num_obs_categories = Rcpp::as(s1); - arma::imat sufficient_blume_capel = Rcpp::as(s2); + arma::imat num_obs_categories = num_obs_categories_group[g]; + arma::imat sufficient_blume_capel = sufficient_blume_capel_group[g]; // Main effects for (int v = 0; v < num_variables; ++v) { @@ -260,9 +258,7 @@ arma::vec gradient( } // Pairwise (observed) - SEXP s3 = sufficient_pairwise_group[g]; - arma::mat sufficient_pairwise = Rcpp::as(s3); - + arma::mat sufficient_pairwise = sufficient_pairwise_group[g]; for (int v1 = 0; v1 < num_variables - 1; ++v1) { for (int v2 = v1 + 1; v2 < num_variables; ++v2) { const int row = pairwise_effect_indices(v1, v2); diff --git a/src/bgmCompare_logp_and_grad.h b/src/bgmCompare_logp_and_grad.h index 46949677..7b679fb5 100644 --- a/src/bgmCompare_logp_and_grad.h +++ b/src/bgmCompare_logp_and_grad.h @@ -13,9 +13,9 @@ double log_pseudoposterior( const arma::imat& observations, const arma::imat& group_indices, const arma::ivec& num_categories, - const Rcpp::List& num_obs_categories_group, - const Rcpp::List& sufficient_blume_capel_group, - const Rcpp::List& sufficient_pairwise_group, + const std::vector& num_obs_categories, + const std::vector& sufficient_blume_capel, + const std::vector& sufficient_pairwise, const int num_groups, const arma::imat& inclusion_indicator, const arma::uvec& is_ordinal_variable, @@ -35,9 +35,9 @@ arma::vec gradient( const arma::imat& observations, const arma::imat& group_indices, const arma::ivec& num_categories, - const Rcpp::List& num_obs_categories_group, - const Rcpp::List& sufficient_blume_capel_group, - const Rcpp::List& sufficient_pairwise_group, + const std::vector& num_obs_categories, + const std::vector& sufficient_blume_capel, + const std::vector& sufficient_pairwise, const int num_groups, const arma::imat& inclusion_indicator, const arma::uvec& is_ordinal_variable, diff --git a/src/bgmCompare_parallel.cpp b/src/bgmCompare_parallel.cpp index 08055898..11bcbbf4 100644 --- a/src/bgmCompare_parallel.cpp +++ b/src/bgmCompare_parallel.cpp @@ -13,11 +13,11 @@ using namespace RcppParallel; // ----------------------------------------------------------------------------- // Result struct // ----------------------------------------------------------------------------- -struct ChainResult { +struct ChainResultCompare { bool error; std::string error_msg; int chain_id; - Rcpp::List result; + SamplerOutput result; }; // ----------------------------------------------------------------------------- @@ -27,9 +27,9 @@ struct GibbsCompareChainRunner : public Worker { // inputs const arma::imat& observations; const int num_groups; - const std::vector& num_obs_categories_cpp_master; - const std::vector& sufficient_blume_capel_cpp_master; - const std::vector& sufficient_pairwise_cpp_master; + const std::vector& num_obs_categories_master; + const std::vector& sufficient_blume_capel_master; + const std::vector& sufficient_pairwise_master; const arma::ivec& num_categories; const double main_alpha; const double main_beta; @@ -60,47 +60,47 @@ struct GibbsCompareChainRunner : public Worker { const std::vector& chain_rngs; // output - std::vector& results; + std::vector& results; GibbsCompareChainRunner( const arma::imat& observations, - const int num_groups, - const std::vector& num_obs_categories_cpp_master, - const std::vector& sufficient_blume_capel_cpp_master, - const std::vector& sufficient_pairwise_cpp_master, + int num_groups, + const std::vector& num_obs_categories_master, + const std::vector& sufficient_blume_capel_master, + const std::vector& sufficient_pairwise_master, const arma::ivec& num_categories, - const double main_alpha, - const double main_beta, - const double pairwise_scale, - const double difference_scale, - const double difference_selection_alpha, - const double difference_selection_beta, + double main_alpha, + double main_beta, + double pairwise_scale, + double difference_scale, + double difference_selection_alpha, + double difference_selection_beta, const std::string& difference_prior, - const int iter, - const int burnin, - const bool na_impute, + int iter, + int burnin, + bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, - const bool difference_selection, + bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, - const double target_accept, - const int nuts_max_depth, - const bool learn_mass_matrix, + double target_accept, + int nuts_max_depth, + bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability_master, const std::vector& chain_rngs, - std::vector& results + std::vector& results ) : observations(observations), num_groups(num_groups), - num_obs_categories_cpp_master(num_obs_categories_cpp_master), - sufficient_blume_capel_cpp_master(sufficient_blume_capel_cpp_master), - sufficient_pairwise_cpp_master(sufficient_pairwise_cpp_master), + num_obs_categories_master(num_obs_categories_master), + sufficient_blume_capel_master(sufficient_blume_capel_master), + sufficient_pairwise_master(sufficient_pairwise_master), num_categories(num_categories), main_alpha(main_alpha), main_beta(main_beta), @@ -132,7 +132,7 @@ struct GibbsCompareChainRunner : public Worker { void operator()(std::size_t begin, std::size_t end) { for (std::size_t i = begin; i < end; ++i) { - ChainResult out; + ChainResultCompare out; out.chain_id = static_cast(i + 1); out.error = false; @@ -141,19 +141,14 @@ struct GibbsCompareChainRunner : public Worker { SafeRNG rng(chain_rngs[i]); // make per-chain copies - std::vector num_obs_categories_cpp = num_obs_categories_cpp_master; - std::vector sufficient_blume_capel_cpp = sufficient_blume_capel_cpp_master; - std::vector sufficient_pairwise_cpp = sufficient_pairwise_cpp_master; + std::vector num_obs_categories = num_obs_categories_master; + std::vector sufficient_blume_capel = sufficient_blume_capel_master; + std::vector sufficient_pairwise = sufficient_pairwise_master; arma::mat inclusion_probability = inclusion_probability_master; arma::imat observations_copy = observations; - // convert vectors -> Rcpp::List - Rcpp::List num_obs_categories(num_obs_categories_cpp.begin(), num_obs_categories_cpp.end()); - Rcpp::List sufficient_blume_capel(sufficient_blume_capel_cpp.begin(), sufficient_blume_capel_cpp.end()); - Rcpp::List sufficient_pairwise(sufficient_pairwise_cpp.begin(), sufficient_pairwise_cpp.end()); - - // run sampler - Rcpp::List result = run_gibbs_sampler_for_bgmCompare( + // run sampler (pure C++) + SamplerOutput result = run_gibbs_sampler_for_bgmCompare( out.chain_id, observations_copy, num_groups, @@ -185,7 +180,7 @@ struct GibbsCompareChainRunner : public Worker { group_indices, interaction_index_matrix, inclusion_probability, - rng // <- pass generator + rng ); out.result = result; @@ -210,44 +205,45 @@ struct GibbsCompareChainRunner : public Worker { // [[Rcpp::export]] Rcpp::List run_bgmCompare_parallel( const arma::imat& observations, - const int num_groups, + int num_groups, const std::vector& num_obs_categories, const std::vector& sufficient_blume_capel, const std::vector& sufficient_pairwise, const arma::ivec& num_categories, - const double main_alpha, - const double main_beta, - const double pairwise_scale, - const double difference_scale, - const double difference_selection_alpha, - const double difference_selection_beta, + double main_alpha, + double main_beta, + double pairwise_scale, + double difference_scale, + double difference_selection_alpha, + double difference_selection_beta, const std::string& difference_prior, - const int iter, - const int burnin, - const bool na_impute, + int iter, + int burnin, + bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, - const bool difference_selection, + bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, - const double target_accept, - const int nuts_max_depth, - const bool learn_mass_matrix, + double target_accept, + int nuts_max_depth, + bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, - const int num_chains, - const int nThreads + int num_chains, + int nThreads, + int seed ) { - std::vector results(num_chains); + std::vector results(num_chains); // per-chain seeds std::vector chain_rngs(num_chains); for (int c = 0; c < num_chains; ++c) { - chain_rngs[c] = SafeRNG(/*seed +*/ c); // TODO: this needs a seed passed by the user! + chain_rngs[c] = SafeRNG(seed + c); } @@ -268,6 +264,7 @@ Rcpp::List run_bgmCompare_parallel( parallelFor(0, num_chains, worker); } + // wrap results back into Rcpp::List Rcpp::List output(num_chains); for (int i = 0; i < num_chains; ++i) { if (results[i].error) { @@ -276,7 +273,19 @@ Rcpp::List run_bgmCompare_parallel( Rcpp::Named("chain_id") = results[i].chain_id ); } else { - output[i] = results[i].result; + const auto& r = results[i].result; + Rcpp::List chain_out = Rcpp::List::create( + Rcpp::Named("main_samples") = r.main_samples, + Rcpp::Named("pairwise_samples") = r.pairwise_samples, + Rcpp::Named("treedepth__") = r.treedepth_samples, + Rcpp::Named("divergent__") = r.divergent_samples, + Rcpp::Named("energy__") = r.energy_samples, + Rcpp::Named("chain_id") = r.chain_id + ); + if (r.has_indicator) { + chain_out["indicator_samples"] = r.indicator_samples; + } + output[i] = chain_out; } } diff --git a/src/bgmCompare_sampler.cpp b/src/bgmCompare_sampler.cpp index 66ca965e..2653b71b 100644 --- a/src/bgmCompare_sampler.cpp +++ b/src/bgmCompare_sampler.cpp @@ -1,5 +1,4 @@ #include -#include #include "bgmCompare_helper.h" #include "bgmCompare_logp_and_grad.h" #include "bgmCompare_sampler.h" @@ -12,6 +11,7 @@ #include "mcmc_utils.h" #include "print_mutex.h" #include "rng_utils.h" +#include "sampler_output.h" using namespace Rcpp; @@ -45,7 +45,7 @@ using namespace Rcpp; * - `sufficient_blume_capel`: Updated sufficient statistics for Blume-Capel variables. * - `residual_matrix`: Updated residual effects matrix. */ -List impute_missing_data_for_graphical_model( +void impute_missing_data_for_graphical_model( const arma::mat& main_effects, const arma::mat& pairwise_effects, const arma::imat& main_effect_indices, @@ -56,9 +56,9 @@ List impute_missing_data_for_graphical_model( const int num_groups, const arma::ivec& group_membership, const arma::imat& group_indices, - List& num_obs_categories, - List& sufficient_blume_capel, - List& sufficient_pairwise, + std::vector& num_obs_categories, + std::vector& sufficient_blume_capel, + std::vector& sufficient_pairwise, const arma::imat& num_categories, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, @@ -167,15 +167,13 @@ List impute_missing_data_for_graphical_model( } } - return List::create(Named("observations") = observations, - Named("num_obs_categories") = num_obs_categories, - Named("sufficient_blume_capel") = sufficient_blume_capel); + return; } /** - * Function: find_reasonable_initial_step_size + * Function: find_reasonable_initial_step_size_cmp * * Heuristically finds a reasonable initial step size for leapfrog-based MCMC algorithms * (such as HMC and NUTS), following the procedure described in: @@ -213,7 +211,7 @@ List impute_missing_data_for_graphical_model( * - This function is suitable for both NUTS and standard HMC algorithms. * - It is typically called once before warm-up/adaptation. */ -double find_reasonable_initial_step_size( +double find_reasonable_initial_step_size_cmp( arma::mat& main_effects, arma::mat& pairwise_effects, const arma::imat& main_effect_indices, @@ -224,9 +222,9 @@ double find_reasonable_initial_step_size( const arma::imat& observations, const int num_groups, const arma::imat& group_indices, - const Rcpp::List& num_obs_categories, - const Rcpp::List& sufficient_blume_capel, - const Rcpp::List& sufficient_pairwise, + const std::vector& num_obs_categories, + const std::vector& sufficient_blume_capel, + const std::vector& sufficient_pairwise, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, const double pairwise_scale, @@ -336,9 +334,9 @@ SamplerResult update_parameters_with_nuts( const arma::imat& observations, const int num_groups, const arma::imat& group_indices, - const Rcpp::List& num_obs_categories, - const Rcpp::List& sufficient_blume_capel, - const Rcpp::List& sufficient_pairwise, + const std::vector& num_obs_categories, + const std::vector& sufficient_blume_capel, + const std::vector& sufficient_pairwise, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, const double pairwise_scale, @@ -585,12 +583,12 @@ SamplerResult update_parameters_with_nuts( * - sqrt_inv_fisher_main * - sqrt_inv_fisher_pairwise */ -void gibbs_update_step_for_graphical_model_parameters ( +void gibbs_update_step_for_graphical_model_parameters_cmp ( const arma::imat& observations, const arma::ivec& num_categories, const double pairwise_scale, - const Rcpp::List& num_obs_categories, - const Rcpp::List& sufficient_blume_capel, + const std::vector& num_obs_categories, + const std::vector& sufficient_blume_capel, const double main_alpha, const double main_beta, arma::imat& inclusion_indicator, @@ -600,7 +598,7 @@ void gibbs_update_step_for_graphical_model_parameters ( const arma::ivec& baseline_category, const int iteration, const arma::imat& pairwise_effect_indices, - const Rcpp::List& sufficient_pairwise, + const std::vector& sufficient_pairwise, const int nuts_max_depth, HMCAdaptationController& adapt, const bool learn_mass_matrix, @@ -613,9 +611,18 @@ void gibbs_update_step_for_graphical_model_parameters ( const int num_groups, const arma::imat group_indices, double difference_scale, - SafeRNG& rng + SafeRNG& rng, + arma::mat& inclusion_probability ) { + // Step 0: Initialise random graph structure when edge_selection = TRUE + if (schedule.selection_enabled(iteration) && iteration == schedule.stage3c_start) { + initialise_graph( + inclusion_indicator, main_effects, pairwise_effects, main_effect_indices, + pairwise_effect_indices, inclusion_probability, rng + ); + } + SamplerResult result = update_parameters_with_nuts( main_effects, pairwise_effects, main_effect_indices, pairwise_effect_indices, inclusion_indicator, projection, num_categories, @@ -638,13 +645,15 @@ void gibbs_update_step_for_graphical_model_parameters ( -Rcpp::List run_gibbs_sampler_for_bgmCompare( +SamplerOutput run_gibbs_sampler_for_bgmCompare( int chain_id, arma::imat observations, const int num_groups, - Rcpp::List num_obs_categories, - Rcpp::List sufficient_blume_capel, - Rcpp::List sufficient_pairwise, + // TODO for Maarten: this will crash horribly when imputing in parallel + // each thread needs an individual copy of the three objects below + std::vector& num_obs_categories, + std::vector& sufficient_blume_capel, + std::vector& sufficient_pairwise, const arma::ivec& num_categories, const double main_alpha, const double main_beta, @@ -652,7 +661,7 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare( const double difference_scale,//new const double difference_selection_alpha,//new const double difference_selection_beta,//new - const std::string difference_prior,//new + const std::string& difference_prior,//new const int iter, const int burnin, const bool na_impute, @@ -660,12 +669,12 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare( const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, const bool difference_selection,//new - const arma::imat main_effect_indices, - const arma::imat pairwise_effect_indices, + const arma::imat& main_effect_indices, + const arma::imat& pairwise_effect_indices, const double target_accept, const int nuts_max_depth, const bool learn_mass_matrix, - const arma::mat projection,//new + const arma::mat& projection,//new const arma::ivec& group_membership,//new const arma::imat& group_indices,//new const arma::imat& interaction_index_matrix,//new @@ -705,7 +714,8 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare( // --- Optional HMC/NUTS warmup stage double initial_step_size = 1.0; - initial_step_size = find_reasonable_initial_step_size( + + initial_step_size = find_reasonable_initial_step_size_cmp( main_effects, pairwise_effects, main_effect_indices, pairwise_effect_indices, inclusion_indicator, projection, num_categories, observations, num_groups, group_indices, num_obs_categories, @@ -728,7 +738,8 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare( for (int iteration = 0; iteration < total_iter; iteration++) { if (iteration % print_every == 0) { tbb::mutex::scoped_lock lock(get_print_mutex()); - Rcpp::Rcout + //Rcpp::Rcout + std::cout << "[bgm] chain " << chain_id << " iteration " << iteration << " / " << total_iter @@ -743,6 +754,7 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare( // Optional imputation if (na_impute) { + impute_missing_data_for_graphical_model ( main_effects, pairwise_effects, main_effect_indices, pairwise_effect_indices, inclusion_indicator, projection, @@ -754,7 +766,7 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare( } // Main Gibbs update step for parameters - gibbs_update_step_for_graphical_model_parameters ( + gibbs_update_step_for_graphical_model_parameters_cmp ( observations, num_categories, pairwise_scale, num_obs_categories, sufficient_blume_capel, main_alpha, main_beta, inclusion_indicator, pairwise_effects, main_effects, is_ordinal_variable, baseline_category, @@ -762,7 +774,7 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare( adapt_joint, learn_mass_matrix, warmup_schedule, treedepth_samples, divergent_samples, energy_samples, main_effect_indices, projection, num_groups, group_indices, difference_scale,//new line of args - rng + rng, inclusion_probability ); // --- Update difference probabilities under the prior (if difference selection is active) @@ -779,8 +791,9 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare( for(int i = 0; i < num_variables; i++) { sumG += inclusion_indicator(i, i); } - double prob = rbeta(rng, difference_selection_alpha + sumG, - difference_selection_beta + num_pair + num_variables - sumG); + double prob = rbeta(rng, + difference_selection_alpha + sumG, + difference_selection_beta + num_pair + num_variables - sumG); std::fill(inclusion_probability.begin(), inclusion_probability.end(), prob); } } @@ -818,18 +831,18 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare( } } - Rcpp::List out; - out["main_samples"] = main_effect_samples; - out["pairwise_samples"] = pairwise_effect_samples; - - out["treedepth__"] = treedepth_samples; - out["divergent__"] = divergent_samples; - out["energy__"] = energy_samples; - + SamplerOutput out; + out.chain_id = chain_id; + out.main_samples = main_effect_samples; + out.pairwise_samples = pairwise_effect_samples; + out.treedepth_samples = treedepth_samples; + out.divergent_samples = divergent_samples; + out.energy_samples = energy_samples; + out.has_indicator = difference_selection; if (difference_selection) { - out["indicator_samples"] = indicator_samples; + out.indicator_samples = indicator_samples; + } else { + out.indicator_samples = arma::imat(); } - - out["chain_id"] = chain_id; return out; } \ No newline at end of file diff --git a/src/bgmCompare_sampler.h b/src/bgmCompare_sampler.h index 9ddc7d3b..baeb5dd8 100644 --- a/src/bgmCompare_sampler.h +++ b/src/bgmCompare_sampler.h @@ -1,39 +1,40 @@ #pragma once #include +#include "sampler_output.h" struct SafeRNG; -Rcpp::List run_gibbs_sampler_for_bgmCompare( +SamplerOutput run_gibbs_sampler_for_bgmCompare( int chain_id, arma::imat observations, const int num_groups, - Rcpp::List num_obs_categories, - Rcpp::List sufficient_blume_capel, - Rcpp::List sufficient_pairwise, + std::vector& num_obs_categories, + std::vector& sufficient_blume_capel, + std::vector& sufficient_pairwise, const arma::ivec& num_categories, const double main_alpha, const double main_beta, const double pairwise_scale, - const double difference_scale,//new - const double difference_selection_alpha,//new - const double difference_selection_beta,//new - const std::string difference_prior,//new + const double difference_scale, + const double difference_selection_alpha, + const double difference_selection_beta, + const std::string& difference_prior, const int iter, const int burnin, const bool na_impute, - const arma::imat& missing_data_indices,//updated + const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, - const bool difference_selection,//new - const arma::imat main_effect_indices, - const arma::imat pairwise_effect_indices, + const bool difference_selection, + const arma::imat& main_effect_indices, + const arma::imat& pairwise_effect_indices, const double target_accept, const int nuts_max_depth, const bool learn_mass_matrix, - const arma::mat projection,//new - const arma::ivec& group_membership,//new - const arma::imat& group_indices,//new - const arma::imat& interaction_index_matrix,//new + const arma::mat& projection, + const arma::ivec& group_membership, + const arma::imat& group_indices, + const arma::imat& interaction_index_matrix, arma::mat inclusion_probability, SafeRNG& rng ); \ No newline at end of file diff --git a/src/rng_utils.h b/src/rng_utils.h index c21ac83f..68dd0bd0 100644 --- a/src/rng_utils.h +++ b/src/rng_utils.h @@ -107,4 +107,4 @@ inline arma::uvec arma_randperm(SafeRNG& rng, arma::uword n) { std::iota(out.begin(), out.end(), 0); std::shuffle(out.begin(), out.end(), rng.eng); return out; -} +} \ No newline at end of file diff --git a/src/sampler_output.h b/src/sampler_output.h new file mode 100644 index 00000000..7c40ffc6 --- /dev/null +++ b/src/sampler_output.h @@ -0,0 +1,18 @@ +#ifndef SAMPLEROUTPUT_H +#define SAMPLEROUTPUT_H + +#include + +// Plain C++ struct, no Rcpp types +struct SamplerOutput { + arma::mat main_samples; + arma::mat pairwise_samples; + arma::imat indicator_samples; + arma::ivec treedepth_samples; + arma::ivec divergent_samples; + arma::vec energy_samples; + int chain_id; + bool has_indicator; +}; + +#endif diff --git a/tests/testthat/test-bgm.R b/tests/testthat/test-bgm.R index c54c034a..26a2d630 100644 --- a/tests/testthat/test-bgm.R +++ b/tests/testthat/test-bgm.R @@ -11,3 +11,52 @@ test_that("inclusion probabilities correlate with posterior mode", { testthat::expect_gte(cor(abs(posterior_modes), posterior_incl_probs, method = "spearman"), .9) }) + +on_ci <- isTRUE(as.logical(Sys.getenv("CI", "false"))) +no_cores <- if (on_ci) 2L else min(4, parallel::detectCores()) + +test_that("bgm is reproducible", { + data("Wenchuan", package = "bgms") + x <- Wenchuan[1:50, 1:5] + fit1 <- bgm(x = x, iter = 100, burnin = 1000, cores = no_cores, seed = 1234) + fit2 <- bgm(x = x, iter = 100, burnin = 1000, cores = no_cores, seed = 1234) + + testthat::expect_equal(fit1$raw_samples, fit2$raw_samples) +}) + +test_that("bgmCompare is reproducible", { + data("Wenchuan", package = "bgms") + x <- Wenchuan[1:50, 1:5] + y <- Wenchuan[1:50, c(1:4, 6)] + fit1 <- bgmCompare2(x = x, y = y, iter = 100, burnin = 1000, cores = no_cores, seed = 1234) + fit2 <- bgmCompare2(x = x, y = y, iter = 100, burnin = 1000, cores = no_cores, seed = 1234) + + combine_chains <- function(lst) { + # without abind + element_names <- names(lst[[1]]) + + out <- lapply(element_names, function(nm) { + arrays <- lapply(lst, function(chain) chain[[nm]]) + dims <- dim(arrays[[1]]) + if (is.null(dims)) { + # handle scalar/vector case (e.g. chain_id) + return(unlist(arrays)) + } else { + # bind along a new last dimension + arr <- array( + unlist(arrays), + dim = c(dims, length(arrays)) + ) + return(arr) + } + }) + + names(out) <- element_names + out + } + + combined1 <- combine_chains(samples1) + combined2 <- combine_chains(samples2) + + testthat::expect_equal(combined1, combined2) +})