Skip to content

Commit 779c83b

Browse files
committed
fix runtime errors
1 parent f089795 commit 779c83b

File tree

9 files changed

+114
-81
lines changed

9 files changed

+114
-81
lines changed

R/RcppExports.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
22
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
33

4-
run_bgmCompare_parallel <- function(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads) {
5-
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads)
4+
run_bgmCompare_parallel <- function(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed) {
5+
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed)
66
}
77

88
run_bgm_parallel <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) {

R/bgmCompare2.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,8 @@ bgmCompare2 = function(
229229
group_indices = group_indices,
230230
interaction_index_matrix = Index,
231231
inclusion_probability = model$inclusion_probability_difference,
232-
num_chains = chains, nThreads = cores
232+
num_chains = chains, nThreads = cores,
233+
seed = seed
233234
)
234235

235236
# Main output handler in the wrapper function

src/RcppExports.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
1212
#endif
1313

1414
// run_bgmCompare_parallel
15-
Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector<arma::imat>& num_obs_categories, const std::vector<arma::imat>& sufficient_blume_capel, const std::vector<arma::mat>& sufficient_pairwise, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads);
16-
RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP sufficient_pairwiseSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP) {
15+
Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector<arma::imat>& num_obs_categories, const std::vector<arma::imat>& sufficient_blume_capel, const std::vector<arma::mat>& sufficient_pairwise, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed);
16+
RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP sufficient_pairwiseSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP) {
1717
BEGIN_RCPP
1818
Rcpp::RObject rcpp_result_gen;
1919
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -49,7 +49,8 @@ BEGIN_RCPP
4949
Rcpp::traits::input_parameter< const arma::mat& >::type inclusion_probability(inclusion_probabilitySEXP);
5050
Rcpp::traits::input_parameter< int >::type num_chains(num_chainsSEXP);
5151
Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP);
52-
rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads));
52+
Rcpp::traits::input_parameter< int >::type seed(seedSEXP);
53+
rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed));
5354
return rcpp_result_gen;
5455
END_RCPP
5556
}
@@ -188,7 +189,7 @@ END_RCPP
188189
}
189190

190191
static const R_CallMethodDef CallEntries[] = {
191-
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 32},
192+
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 33},
192193
{"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 31},
193194
{"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6},
194195
{"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8},

src/bgmCompare_logp_and_grad.cpp

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ double log_pseudoposterior(
1616
const arma::imat& observations,
1717
const arma::imat& group_indices,
1818
const arma::ivec& num_categories,
19-
const Rcpp::List& num_obs_categories_group,
20-
const Rcpp::List& sufficient_blume_capel_group,
21-
const Rcpp::List& sufficient_pairwise_group,
19+
const std::vector<arma::imat>& num_obs_categories_group,
20+
const std::vector<arma::imat>& sufficient_blume_capel_group,
21+
const std::vector<arma::mat>& sufficient_pairwise_group,
2222
const int num_groups,
2323
const arma::imat& inclusion_indicator,
2424
const arma::uvec& is_ordinal_variable,
@@ -34,8 +34,8 @@ double log_pseudoposterior(
3434

3535
// --- per group ---
3636
for (int group = 0; group < num_groups; ++group) {
37-
const arma::imat num_obs_categories = Rcpp::as<arma::imat>(num_obs_categories_group[group]);
38-
const arma::imat sufficient_blume_capel = Rcpp::as<arma::imat>(sufficient_blume_capel_group[group]);
37+
const arma::imat num_obs_categories = num_obs_categories_group[group];
38+
const arma::imat sufficient_blume_capel = sufficient_blume_capel_group[group];
3939

4040
arma::mat main_group(num_variables, max_num_categories, arma::fill::zeros);
4141
arma::mat pairwise_group(num_variables, num_variables, arma::fill::zeros);
@@ -81,7 +81,7 @@ double log_pseudoposterior(
8181
const int r0 = group_indices(group, 0);
8282
const int r1 = group_indices(group, 1);
8383
const arma::mat obs = arma::conv_to<arma::mat>::from(observations.rows(r0, r1));
84-
const arma::mat sufficient_pairwise = Rcpp::as<arma::mat>(sufficient_pairwise_group[group]);
84+
const arma::mat sufficient_pairwise = sufficient_pairwise_group[group];
8585

8686
log_pp += arma::accu(pairwise_group % sufficient_pairwise); // trace(X' * W * X) = sum(W %*% (X'X))
8787

@@ -170,9 +170,9 @@ arma::vec gradient(
170170
const arma::imat& observations,
171171
const arma::imat& group_indices,
172172
const arma::ivec& num_categories,
173-
const Rcpp::List& num_obs_categories_group,
174-
const Rcpp::List& sufficient_blume_capel_group,
175-
const Rcpp::List& sufficient_pairwise_group,
173+
const std::vector<arma::imat>& num_obs_categories_group,
174+
const std::vector<arma::imat>& sufficient_blume_capel_group,
175+
const std::vector<arma::mat>& sufficient_pairwise_group,
176176
const int num_groups,
177177
const arma::imat& inclusion_indicator,
178178
const arma::uvec& is_ordinal_variable,
@@ -214,10 +214,8 @@ arma::vec gradient(
214214
// -------------------------------
215215
for (int g = 0; g < num_groups; ++g) {
216216
// list access
217-
SEXP s1 = num_obs_categories_group[g];
218-
SEXP s2 = sufficient_blume_capel_group[g];
219-
arma::imat num_obs_categories = Rcpp::as<arma::imat>(s1);
220-
arma::imat sufficient_blume_capel = Rcpp::as<arma::imat>(s2);
217+
arma::imat num_obs_categories = num_obs_categories_group[g];
218+
arma::imat sufficient_blume_capel = sufficient_blume_capel_group[g];
221219

222220
// Main effects
223221
for (int v = 0; v < num_variables; ++v) {
@@ -260,9 +258,7 @@ arma::vec gradient(
260258
}
261259

262260
// Pairwise (observed)
263-
SEXP s3 = sufficient_pairwise_group[g];
264-
arma::mat sufficient_pairwise = Rcpp::as<arma::mat>(s3);
265-
261+
arma::mat sufficient_pairwise = sufficient_pairwise_group[g];
266262
for (int v1 = 0; v1 < num_variables - 1; ++v1) {
267263
for (int v2 = v1 + 1; v2 < num_variables; ++v2) {
268264
const int row = pairwise_effect_indices(v1, v2);

src/bgmCompare_logp_and_grad.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ double log_pseudoposterior(
1313
const arma::imat& observations,
1414
const arma::imat& group_indices,
1515
const arma::ivec& num_categories,
16-
const Rcpp::List& num_obs_categories_group,
17-
const Rcpp::List& sufficient_blume_capel_group,
18-
const Rcpp::List& sufficient_pairwise_group,
16+
const std::vector<arma::imat>& num_obs_categories,
17+
const std::vector<arma::imat>& sufficient_blume_capel,
18+
const std::vector<arma::mat>& sufficient_pairwise,
1919
const int num_groups,
2020
const arma::imat& inclusion_indicator,
2121
const arma::uvec& is_ordinal_variable,
@@ -35,9 +35,9 @@ arma::vec gradient(
3535
const arma::imat& observations,
3636
const arma::imat& group_indices,
3737
const arma::ivec& num_categories,
38-
const Rcpp::List& num_obs_categories_group,
39-
const Rcpp::List& sufficient_blume_capel_group,
40-
const Rcpp::List& sufficient_pairwise_group,
38+
const std::vector<arma::imat>& num_obs_categories,
39+
const std::vector<arma::imat>& sufficient_blume_capel,
40+
const std::vector<arma::mat>& sufficient_pairwise,
4141
const int num_groups,
4242
const arma::imat& inclusion_indicator,
4343
const arma::uvec& is_ordinal_variable,

src/bgmCompare_parallel.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,15 @@ Rcpp::List run_bgmCompare_parallel(
235235
const arma::imat& interaction_index_matrix,
236236
const arma::mat& inclusion_probability,
237237
int num_chains,
238-
int nThreads
238+
int nThreads,
239+
int seed
239240
) {
240241
std::vector<ChainResultCompare> results(num_chains);
241242

242243
// per-chain seeds
243244
std::vector<SafeRNG> chain_rngs(num_chains);
244245
for (int c = 0; c < num_chains; ++c) {
245-
chain_rngs[c] = SafeRNG(/*seed +*/ c); // TODO: this needs a seed passed by the user!
246+
chain_rngs[c] = SafeRNG(seed + c);
246247
}
247248

248249

0 commit comments

Comments
 (0)