diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 64c0084b..0d829f60 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -58,7 +58,7 @@ BEGIN_RCPP END_RCPP } // run_bgm_parallel -Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int warmup, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, uint64_t seed, int progress_type); +Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int warmup, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, int seed, int progress_type); RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP warmupSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; @@ -93,7 +93,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< bool >::type learn_mass_matrix(learn_mass_matrixSEXP); Rcpp::traits::input_parameter< int >::type num_chains(num_chainsSEXP); Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP); - Rcpp::traits::input_parameter< uint64_t >::type seed(seedSEXP); + Rcpp::traits::input_parameter< int >::type seed(seedSEXP); Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)); return rcpp_result_gen; diff --git a/src/bgmCompare_parallel.cpp b/src/bgmCompare_parallel.cpp index fbdda065..9ded8dc0 100644 --- a/src/bgmCompare_parallel.cpp +++ b/src/bgmCompare_parallel.cpp @@ -1,11 +1,11 @@ // [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]] -#include #include #include "bgmCompare_sampler.h" +#include "rng_utils.h" // must be included before RcppParallel +#include #include #include #include -#include "rng_utils.h" #include "progress_manager.h" #include "sampler_output.h" #include "mcmc_adaptation.h" @@ -219,7 +219,7 @@ struct GibbsCompareChainRunner : public Worker { try { // per-chain RNG - SafeRNG rng(chain_rngs[i]); + SafeRNG rng = chain_rngs[i]; // make per-chain copies std::vector counts_per_category = counts_per_category_master; diff --git a/src/bgmCompare_sampler.cpp b/src/bgmCompare_sampler.cpp index d823249a..15ff19b3 100644 --- a/src/bgmCompare_sampler.cpp +++ b/src/bgmCompare_sampler.cpp @@ -9,7 +9,6 @@ #include "mcmc_nuts.h" #include "mcmc_rwm.h" #include "mcmc_utils.h" -#include "print_mutex.h" #include "rng_utils.h" #include "sampler_output.h" #include "explog_switch.h" @@ -1645,7 +1644,6 @@ SamplerOutput run_gibbs_sampler_bgmCompare( ); const int total_iter = warmup_schedule.total_warmup + iter; - const int print_every = std::max(1, total_iter / 10); // --- Main Gibbs sampling loop bool userInterrupt = false; diff --git a/src/bgm_parallel.cpp b/src/bgm_parallel.cpp index eb397cb1..7172d165 100644 --- a/src/bgm_parallel.cpp +++ b/src/bgm_parallel.cpp @@ -1,11 +1,11 @@ // [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]] -#include #include +#include "rng_utils.h" // must be included before RcppParallel +#include #include "bgm_sampler.h" #include #include #include -#include "rng_utils.h" #include "progress_manager.h" #include "mcmc_adaptation.h" @@ -297,7 +297,7 @@ Rcpp::List run_bgm_parallel( bool learn_mass_matrix, int num_chains, int nThreads, - uint64_t seed, + int seed, int progress_type ) { std::vector results(num_chains); diff --git a/src/bgm_sampler.cpp b/src/bgm_sampler.cpp index 40b1099b..1c187ef0 100644 --- a/src/bgm_sampler.cpp +++ b/src/bgm_sampler.cpp @@ -1,5 +1,4 @@ #include -#include #include "bgm_helper.h" #include "bgm_logp_and_grad.h" #include "bgm_sampler.h" diff --git a/src/mcmc_nuts.cpp b/src/mcmc_nuts.cpp index 81f74d20..42d4abea 100644 --- a/src/mcmc_nuts.cpp +++ b/src/mcmc_nuts.cpp @@ -4,7 +4,6 @@ #include "mcmc_memoization.h" #include "mcmc_nuts.h" #include "mcmc_utils.h" -#include #include "rng_utils.h" using namespace Rcpp; diff --git a/src/rng_utils.h b/src/rng_utils.h index 68dd0bd0..ae9ed250 100644 --- a/src/rng_utils.h +++ b/src/rng_utils.h @@ -1,6 +1,7 @@ // [[Rcpp::depends(BH)]] #pragma once +// the order of these two is mandatory, RcppArmadillo must com before dqrng #include #include #include @@ -9,6 +10,8 @@ #include #include +// [[Rcpp::depends(dqrng, BH)]] + struct SafeRNG { dqrng::xoshiro256plusplus eng; @@ -107,4 +110,5 @@ inline arma::uvec arma_randperm(SafeRNG& rng, arma::uword n) { std::iota(out.begin(), out.end(), 0); std::shuffle(out.begin(), out.end(), rng.eng); return out; -} \ No newline at end of file +} + diff --git a/src/sbm_edge_prior.cpp b/src/sbm_edge_prior.cpp index a1f9dd22..e8cc3c0a 100644 --- a/src/sbm_edge_prior.cpp +++ b/src/sbm_edge_prior.cpp @@ -1,5 +1,4 @@ #include -#include #include "rng_utils.h" #include "explog_switch.h"