Skip to content

Commit ad2386b

Browse files
authored
Merge pull request #54 from Bayesian-Graphical-Modelling-Lab/adaMala_debug
2 parents d005c5c + 779c83b commit ad2386b

17 files changed

+306
-245
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ S3method(print,summary.bgms)
1616
S3method(summary,bgms)
1717
export(bgm)
1818
export(bgmCompare)
19+
export(bgmCompare2)
1920
export(extract_arguments)
2021
export(extract_category_thresholds.bgms)
2122
export(extract_edge_indicators)

R/RcppExports.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
22
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
33

4-
run_bgmCompare_parallel <- function(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads) {
5-
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads)
4+
run_bgmCompare_parallel <- function(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed) {
5+
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed)
66
}
77

88
run_bgm_parallel <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) {

R/bgmCompare2.R

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ bgmCompare2 = function(
196196
stop("Argument 'seed' must be a non-negative integer or vector of non-negative integers.")
197197
}
198198
# Force to integer type
199-
seed <- as.integer(seed)
199+
seed = as.integer(seed)
200200
dqrng::dqset.seed(seed)
201201
}
202202

@@ -225,44 +225,14 @@ bgmCompare2 = function(
225225
nuts_max_depth = nuts_max_depth,
226226
learn_mass_matrix = learn_mass_matrix,
227227
projection = projection,
228-
group_membership = sorted_group - 1, ######################################
228+
group_membership = sorted_group - 1,
229229
group_indices = group_indices,
230230
interaction_index_matrix = Index,
231231
inclusion_probability = model$inclusion_probability_difference,
232-
num_chains = chains, nThreads = cores
232+
num_chains = chains, nThreads = cores,
233+
seed = seed
233234
)
234235

235-
# out = run_gibbs_sampler_for_bgmCompare(
236-
# chain_id = 1,
237-
# observations = observations,
238-
# num_groups = num_groups,
239-
# num_obs_categories = num_obs_categories,
240-
# sufficient_blume_capel = sufficient_blume_capel,
241-
# sufficient_pairwise = sufficient_pairwise,
242-
# num_categories = num_categories[, 1],
243-
# main_alpha = threshold_alpha,
244-
# main_beta = threshold_beta,
245-
# pairwise_scale = interaction_scale,
246-
# difference_scale = difference_scale,
247-
# difference_selection_alpha = beta_bernoulli_alpha,
248-
# difference_selection_beta = beta_bernoulli_beta,
249-
# difference_prior = model$difference_prior, iter = iter, burnin = burnin,
250-
# na_impute = na_impute, missing_data_indices = missing_index,
251-
# is_ordinal_variable = ordinal_variable,
252-
# baseline_category = reference_category,
253-
# difference_selection = difference_selection,
254-
# main_effect_indices = main_effect_indices,
255-
# pairwise_effect_indices = pairwise_effect_indices,
256-
# target_accept = target_accept,
257-
# nuts_max_depth = nuts_max_depth,
258-
# learn_mass_matrix = learn_mass_matrix,
259-
# projection = projection,
260-
# group_membership = sorted_group - 1, ######################################
261-
# group_indices = group_indices,
262-
# interaction_index_matrix = Index,
263-
# inclusion_probability = model$inclusion_probability_difference)
264-
265-
266236
# Main output handler in the wrapper function
267237
# output = prepare_output_bgmCompare2(
268238
# out = out, ...

R/function_input_utils.R

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -737,8 +737,10 @@ check_compare2_model = function(
737737
if(difference_inclusion_probability >= 1)
738738
stop("The inclusion probability for differences cannot equal or exceed the value one.")
739739

740-
inclusion_probability_difference = inclusion_probability_difference +
741-
difference_probability
740+
inclusion_probability_difference = matrix(difference_probability,
741+
nrow = ncol(x),
742+
ncol = ncol(x))
743+
742744
} else {
743745
if(!inherits(difference_probability, what = "matrix") &&
744746
!inherits(difference_probability, what = "data.frame"))
@@ -765,8 +767,9 @@ check_compare2_model = function(
765767
stop("One or more inclusion probabilities for differences are one or larger.")
766768
}
767769
} else {
768-
inclusion_probability_difference = inclusion_probability_difference + 0.5
769-
770+
inclusion_probability_difference = matrix(0.5,
771+
nrow = ncol(x),
772+
ncol = ncol(x))
770773
if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0)
771774
stop("The scale parameters of the beta distribution for the differences need to be positive.")
772775
if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta))

man/bgm.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/summarySBM.Rd

Lines changed: 0 additions & 38 deletions
This file was deleted.

src/Makevars

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
CXX_STD = CXX14
1+
CXX_STD = CXX20
22

33
## Pull in the include-paths for RcppParallel
44
PKG_CPPFLAGS = \
@@ -11,4 +11,4 @@ PKG_CPPFLAGS = \
1111
PKG_LIBS = \
1212
$(shell "${R_HOME}/bin/Rscript" -e "cat(RcppParallel::LdFlags())") \
1313
-Wl,-rpath,$(shell "${R_HOME}/bin/Rscript" -e "cat(system.file('lib',package='RcppParallel'))") \
14-
$(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)
14+
$(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)

src/RcppExports.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,44 +12,45 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
1212
#endif
1313

1414
// run_bgmCompare_parallel
15-
Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, const int num_groups, const std::vector<arma::imat>& num_obs_categories, const std::vector<arma::imat>& sufficient_blume_capel, const std::vector<arma::mat>& sufficient_pairwise, const arma::ivec& num_categories, const double main_alpha, const double main_beta, const double pairwise_scale, const double difference_scale, const double difference_selection_alpha, const double difference_selection_beta, const std::string& difference_prior, const int iter, const int burnin, const bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, const bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, const double target_accept, const int nuts_max_depth, const bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, const int num_chains, const int nThreads);
16-
RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP sufficient_pairwiseSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP) {
15+
Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector<arma::imat>& num_obs_categories, const std::vector<arma::imat>& sufficient_blume_capel, const std::vector<arma::mat>& sufficient_pairwise, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed);
16+
RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP sufficient_pairwiseSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP) {
1717
BEGIN_RCPP
1818
Rcpp::RObject rcpp_result_gen;
1919
Rcpp::RNGScope rcpp_rngScope_gen;
2020
Rcpp::traits::input_parameter< const arma::imat& >::type observations(observationsSEXP);
21-
Rcpp::traits::input_parameter< const int >::type num_groups(num_groupsSEXP);
21+
Rcpp::traits::input_parameter< int >::type num_groups(num_groupsSEXP);
2222
Rcpp::traits::input_parameter< const std::vector<arma::imat>& >::type num_obs_categories(num_obs_categoriesSEXP);
2323
Rcpp::traits::input_parameter< const std::vector<arma::imat>& >::type sufficient_blume_capel(sufficient_blume_capelSEXP);
2424
Rcpp::traits::input_parameter< const std::vector<arma::mat>& >::type sufficient_pairwise(sufficient_pairwiseSEXP);
2525
Rcpp::traits::input_parameter< const arma::ivec& >::type num_categories(num_categoriesSEXP);
26-
Rcpp::traits::input_parameter< const double >::type main_alpha(main_alphaSEXP);
27-
Rcpp::traits::input_parameter< const double >::type main_beta(main_betaSEXP);
28-
Rcpp::traits::input_parameter< const double >::type pairwise_scale(pairwise_scaleSEXP);
29-
Rcpp::traits::input_parameter< const double >::type difference_scale(difference_scaleSEXP);
30-
Rcpp::traits::input_parameter< const double >::type difference_selection_alpha(difference_selection_alphaSEXP);
31-
Rcpp::traits::input_parameter< const double >::type difference_selection_beta(difference_selection_betaSEXP);
26+
Rcpp::traits::input_parameter< double >::type main_alpha(main_alphaSEXP);
27+
Rcpp::traits::input_parameter< double >::type main_beta(main_betaSEXP);
28+
Rcpp::traits::input_parameter< double >::type pairwise_scale(pairwise_scaleSEXP);
29+
Rcpp::traits::input_parameter< double >::type difference_scale(difference_scaleSEXP);
30+
Rcpp::traits::input_parameter< double >::type difference_selection_alpha(difference_selection_alphaSEXP);
31+
Rcpp::traits::input_parameter< double >::type difference_selection_beta(difference_selection_betaSEXP);
3232
Rcpp::traits::input_parameter< const std::string& >::type difference_prior(difference_priorSEXP);
33-
Rcpp::traits::input_parameter< const int >::type iter(iterSEXP);
34-
Rcpp::traits::input_parameter< const int >::type burnin(burninSEXP);
35-
Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP);
33+
Rcpp::traits::input_parameter< int >::type iter(iterSEXP);
34+
Rcpp::traits::input_parameter< int >::type burnin(burninSEXP);
35+
Rcpp::traits::input_parameter< bool >::type na_impute(na_imputeSEXP);
3636
Rcpp::traits::input_parameter< const arma::imat& >::type missing_data_indices(missing_data_indicesSEXP);
3737
Rcpp::traits::input_parameter< const arma::uvec& >::type is_ordinal_variable(is_ordinal_variableSEXP);
3838
Rcpp::traits::input_parameter< const arma::ivec& >::type baseline_category(baseline_categorySEXP);
39-
Rcpp::traits::input_parameter< const bool >::type difference_selection(difference_selectionSEXP);
39+
Rcpp::traits::input_parameter< bool >::type difference_selection(difference_selectionSEXP);
4040
Rcpp::traits::input_parameter< const arma::imat& >::type main_effect_indices(main_effect_indicesSEXP);
4141
Rcpp::traits::input_parameter< const arma::imat& >::type pairwise_effect_indices(pairwise_effect_indicesSEXP);
42-
Rcpp::traits::input_parameter< const double >::type target_accept(target_acceptSEXP);
43-
Rcpp::traits::input_parameter< const int >::type nuts_max_depth(nuts_max_depthSEXP);
44-
Rcpp::traits::input_parameter< const bool >::type learn_mass_matrix(learn_mass_matrixSEXP);
42+
Rcpp::traits::input_parameter< double >::type target_accept(target_acceptSEXP);
43+
Rcpp::traits::input_parameter< int >::type nuts_max_depth(nuts_max_depthSEXP);
44+
Rcpp::traits::input_parameter< bool >::type learn_mass_matrix(learn_mass_matrixSEXP);
4545
Rcpp::traits::input_parameter< const arma::mat& >::type projection(projectionSEXP);
4646
Rcpp::traits::input_parameter< const arma::ivec& >::type group_membership(group_membershipSEXP);
4747
Rcpp::traits::input_parameter< const arma::imat& >::type group_indices(group_indicesSEXP);
4848
Rcpp::traits::input_parameter< const arma::imat& >::type interaction_index_matrix(interaction_index_matrixSEXP);
4949
Rcpp::traits::input_parameter< const arma::mat& >::type inclusion_probability(inclusion_probabilitySEXP);
50-
Rcpp::traits::input_parameter< const int >::type num_chains(num_chainsSEXP);
51-
Rcpp::traits::input_parameter< const int >::type nThreads(nThreadsSEXP);
52-
rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads));
50+
Rcpp::traits::input_parameter< int >::type num_chains(num_chainsSEXP);
51+
Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP);
52+
Rcpp::traits::input_parameter< int >::type seed(seedSEXP);
53+
rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed));
5354
return rcpp_result_gen;
5455
END_RCPP
5556
}
@@ -188,7 +189,7 @@ END_RCPP
188189
}
189190

190191
static const R_CallMethodDef CallEntries[] = {
191-
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 32},
192+
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 33},
192193
{"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 31},
193194
{"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6},
194195
{"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8},

0 commit comments

Comments
 (0)