Skip to content

Commit 669768a

Browse files
Fix for stack imbalance bgmCompare parallel
1 parent 069ec48 commit 669768a

13 files changed

+237
-209
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ S3method(print,summary.bgms)
1616
S3method(summary,bgms)
1717
export(bgm)
1818
export(bgmCompare)
19+
export(bgmCompare2)
1920
export(extract_arguments)
2021
export(extract_category_thresholds.bgms)
2122
export(extract_edge_indicators)

R/bgmCompare2.R

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ bgmCompare2 = function(
196196
stop("Argument 'seed' must be a non-negative integer or vector of non-negative integers.")
197197
}
198198
# Force to integer type
199-
seed <- as.integer(seed)
199+
seed = as.integer(seed)
200200
dqrng::dqset.seed(seed)
201201
}
202202

@@ -225,44 +225,13 @@ bgmCompare2 = function(
225225
nuts_max_depth = nuts_max_depth,
226226
learn_mass_matrix = learn_mass_matrix,
227227
projection = projection,
228-
group_membership = sorted_group - 1, ######################################
228+
group_membership = sorted_group - 1,
229229
group_indices = group_indices,
230230
interaction_index_matrix = Index,
231231
inclusion_probability = model$inclusion_probability_difference,
232232
num_chains = chains, nThreads = cores
233233
)
234234

235-
# out = run_gibbs_sampler_for_bgmCompare(
236-
# chain_id = 1,
237-
# observations = observations,
238-
# num_groups = num_groups,
239-
# num_obs_categories = num_obs_categories,
240-
# sufficient_blume_capel = sufficient_blume_capel,
241-
# sufficient_pairwise = sufficient_pairwise,
242-
# num_categories = num_categories[, 1],
243-
# main_alpha = threshold_alpha,
244-
# main_beta = threshold_beta,
245-
# pairwise_scale = interaction_scale,
246-
# difference_scale = difference_scale,
247-
# difference_selection_alpha = beta_bernoulli_alpha,
248-
# difference_selection_beta = beta_bernoulli_beta,
249-
# difference_prior = model$difference_prior, iter = iter, burnin = burnin,
250-
# na_impute = na_impute, missing_data_indices = missing_index,
251-
# is_ordinal_variable = ordinal_variable,
252-
# baseline_category = reference_category,
253-
# difference_selection = difference_selection,
254-
# main_effect_indices = main_effect_indices,
255-
# pairwise_effect_indices = pairwise_effect_indices,
256-
# target_accept = target_accept,
257-
# nuts_max_depth = nuts_max_depth,
258-
# learn_mass_matrix = learn_mass_matrix,
259-
# projection = projection,
260-
# group_membership = sorted_group - 1, ######################################
261-
# group_indices = group_indices,
262-
# interaction_index_matrix = Index,
263-
# inclusion_probability = model$inclusion_probability_difference)
264-
265-
266235
# Main output handler in the wrapper function
267236
# output = prepare_output_bgmCompare2(
268237
# out = out, ...

R/data_utils.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ compute_sufficient_blume_capel = function(x, reference_category, ordinal_variabl
490490

491491
# Helper function for computing sufficient statistics for pairwise interactions
492492
compute_sufficient_pairwise <- function(x, group) {
493-
result <- vector("list", length(group))
493+
result <- vector("list", length(unique(group)))
494494

495495
for (g in unique(group)) {
496496
obs <- x[group == g, , drop = FALSE]

R/function_input_utils.R

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -737,8 +737,10 @@ check_compare2_model = function(
737737
if(difference_inclusion_probability >= 1)
738738
stop("The inclusion probability for differences cannot equal or exceed the value one.")
739739

740-
inclusion_probability_difference = inclusion_probability_difference +
741-
difference_probability
740+
inclusion_probability_difference = matrix(difference_probability,
741+
nrow = ncol(x),
742+
ncol = ncol(x))
743+
742744
} else {
743745
if(!inherits(difference_probability, what = "matrix") &&
744746
!inherits(difference_probability, what = "data.frame"))
@@ -765,8 +767,9 @@ check_compare2_model = function(
765767
stop("One or more inclusion probabilities for differences are one or larger.")
766768
}
767769
} else {
768-
inclusion_probability_difference = inclusion_probability_difference + 0.5
769-
770+
inclusion_probability_difference = matrix(0.5,
771+
nrow = ncol(x),
772+
ncol = ncol(x))
770773
if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0)
771774
stop("The scale parameters of the beta distribution for the differences need to be positive.")
772775
if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta))

man/bgm.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/summarySBM.Rd

Lines changed: 0 additions & 38 deletions
This file was deleted.

src/Makevars

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
CXX_STD = CXX14
1+
CXX_STD = CXX20
22

33
## Pull in the include-paths for RcppParallel
44
PKG_CPPFLAGS = \
@@ -11,4 +11,4 @@ PKG_CPPFLAGS = \
1111
PKG_LIBS = \
1212
$(shell "${R_HOME}/bin/Rscript" -e "cat(RcppParallel::LdFlags())") \
1313
-Wl,-rpath,$(shell "${R_HOME}/bin/Rscript" -e "cat(system.file('lib',package='RcppParallel'))") \
14-
$(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)
14+
$(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)

src/RcppExports.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,43 +12,43 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
1212
#endif
1313

1414
// run_bgmCompare_parallel
15-
Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, const int num_groups, const std::vector<arma::imat>& num_obs_categories, const std::vector<arma::imat>& sufficient_blume_capel, const std::vector<arma::mat>& sufficient_pairwise, const arma::ivec& num_categories, const double main_alpha, const double main_beta, const double pairwise_scale, const double difference_scale, const double difference_selection_alpha, const double difference_selection_beta, const std::string& difference_prior, const int iter, const int burnin, const bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, const bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, const double target_accept, const int nuts_max_depth, const bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, const int num_chains, const int nThreads);
15+
Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector<arma::imat>& num_obs_categories, const std::vector<arma::imat>& sufficient_blume_capel, const std::vector<arma::mat>& sufficient_pairwise, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads);
1616
RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP sufficient_pairwiseSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP) {
1717
BEGIN_RCPP
1818
Rcpp::RObject rcpp_result_gen;
1919
Rcpp::RNGScope rcpp_rngScope_gen;
2020
Rcpp::traits::input_parameter< const arma::imat& >::type observations(observationsSEXP);
21-
Rcpp::traits::input_parameter< const int >::type num_groups(num_groupsSEXP);
21+
Rcpp::traits::input_parameter< int >::type num_groups(num_groupsSEXP);
2222
Rcpp::traits::input_parameter< const std::vector<arma::imat>& >::type num_obs_categories(num_obs_categoriesSEXP);
2323
Rcpp::traits::input_parameter< const std::vector<arma::imat>& >::type sufficient_blume_capel(sufficient_blume_capelSEXP);
2424
Rcpp::traits::input_parameter< const std::vector<arma::mat>& >::type sufficient_pairwise(sufficient_pairwiseSEXP);
2525
Rcpp::traits::input_parameter< const arma::ivec& >::type num_categories(num_categoriesSEXP);
26-
Rcpp::traits::input_parameter< const double >::type main_alpha(main_alphaSEXP);
27-
Rcpp::traits::input_parameter< const double >::type main_beta(main_betaSEXP);
28-
Rcpp::traits::input_parameter< const double >::type pairwise_scale(pairwise_scaleSEXP);
29-
Rcpp::traits::input_parameter< const double >::type difference_scale(difference_scaleSEXP);
30-
Rcpp::traits::input_parameter< const double >::type difference_selection_alpha(difference_selection_alphaSEXP);
31-
Rcpp::traits::input_parameter< const double >::type difference_selection_beta(difference_selection_betaSEXP);
26+
Rcpp::traits::input_parameter< double >::type main_alpha(main_alphaSEXP);
27+
Rcpp::traits::input_parameter< double >::type main_beta(main_betaSEXP);
28+
Rcpp::traits::input_parameter< double >::type pairwise_scale(pairwise_scaleSEXP);
29+
Rcpp::traits::input_parameter< double >::type difference_scale(difference_scaleSEXP);
30+
Rcpp::traits::input_parameter< double >::type difference_selection_alpha(difference_selection_alphaSEXP);
31+
Rcpp::traits::input_parameter< double >::type difference_selection_beta(difference_selection_betaSEXP);
3232
Rcpp::traits::input_parameter< const std::string& >::type difference_prior(difference_priorSEXP);
33-
Rcpp::traits::input_parameter< const int >::type iter(iterSEXP);
34-
Rcpp::traits::input_parameter< const int >::type burnin(burninSEXP);
35-
Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP);
33+
Rcpp::traits::input_parameter< int >::type iter(iterSEXP);
34+
Rcpp::traits::input_parameter< int >::type burnin(burninSEXP);
35+
Rcpp::traits::input_parameter< bool >::type na_impute(na_imputeSEXP);
3636
Rcpp::traits::input_parameter< const arma::imat& >::type missing_data_indices(missing_data_indicesSEXP);
3737
Rcpp::traits::input_parameter< const arma::uvec& >::type is_ordinal_variable(is_ordinal_variableSEXP);
3838
Rcpp::traits::input_parameter< const arma::ivec& >::type baseline_category(baseline_categorySEXP);
39-
Rcpp::traits::input_parameter< const bool >::type difference_selection(difference_selectionSEXP);
39+
Rcpp::traits::input_parameter< bool >::type difference_selection(difference_selectionSEXP);
4040
Rcpp::traits::input_parameter< const arma::imat& >::type main_effect_indices(main_effect_indicesSEXP);
4141
Rcpp::traits::input_parameter< const arma::imat& >::type pairwise_effect_indices(pairwise_effect_indicesSEXP);
42-
Rcpp::traits::input_parameter< const double >::type target_accept(target_acceptSEXP);
43-
Rcpp::traits::input_parameter< const int >::type nuts_max_depth(nuts_max_depthSEXP);
44-
Rcpp::traits::input_parameter< const bool >::type learn_mass_matrix(learn_mass_matrixSEXP);
42+
Rcpp::traits::input_parameter< double >::type target_accept(target_acceptSEXP);
43+
Rcpp::traits::input_parameter< int >::type nuts_max_depth(nuts_max_depthSEXP);
44+
Rcpp::traits::input_parameter< bool >::type learn_mass_matrix(learn_mass_matrixSEXP);
4545
Rcpp::traits::input_parameter< const arma::mat& >::type projection(projectionSEXP);
4646
Rcpp::traits::input_parameter< const arma::ivec& >::type group_membership(group_membershipSEXP);
4747
Rcpp::traits::input_parameter< const arma::imat& >::type group_indices(group_indicesSEXP);
4848
Rcpp::traits::input_parameter< const arma::imat& >::type interaction_index_matrix(interaction_index_matrixSEXP);
4949
Rcpp::traits::input_parameter< const arma::mat& >::type inclusion_probability(inclusion_probabilitySEXP);
50-
Rcpp::traits::input_parameter< const int >::type num_chains(num_chainsSEXP);
51-
Rcpp::traits::input_parameter< const int >::type nThreads(nThreadsSEXP);
50+
Rcpp::traits::input_parameter< int >::type num_chains(num_chainsSEXP);
51+
Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP);
5252
rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads));
5353
return rcpp_result_gen;
5454
END_RCPP

src/bgmCompare_helper.h

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <RcppArmadillo.h>
4+
#include "rng_utils.h"
45

56

67

@@ -65,4 +66,40 @@ arma::vec inv_mass_active(
6566
const arma::imat& main_effect_indices,
6667
const arma::imat& pairwise_effect_indices,
6768
const bool& selection
68-
);
69+
);
70+
71+
inline void initialise_graph(
72+
arma::imat& indicator,
73+
arma::mat& main,
74+
arma::mat& pairwise,
75+
const arma::imat& main_indices,
76+
const arma::imat& pairwise_indices,
77+
const arma::mat& incl_prob,
78+
dqrng::xoshiro256plus& rng
79+
) {
80+
int V = indicator.n_rows;
81+
int G = main.n_cols;
82+
for (int i = 0; i < V-1; ++i) {
83+
for (int j = i+1; j < V; ++j) {
84+
double p = incl_prob(i,j);
85+
int draw = (runif(rng) < p) ? 1 : 0;
86+
indicator(i,j) = indicator(j,i) = draw;
87+
if (!draw) {
88+
int row = pairwise_indices(i, j);
89+
pairwise.row(row).cols(1, G-1) = 0.0;
90+
}
91+
}
92+
}
93+
for(int i = 0; i < V; i++) {
94+
double p = incl_prob(i,i);
95+
int draw = (runif(rng) < p) ? 1 : 0;
96+
indicator(i,i) = draw;
97+
if(!draw) {
98+
int start = main_indices(i,0);
99+
int end = main_indices(i,1);
100+
for(int row = start; row < end; row++) {
101+
main.row(row).cols(1, G-1) = 0.0;
102+
}
103+
}
104+
}
105+
};

0 commit comments

Comments
 (0)