Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ BEGIN_RCPP
END_RCPP
}
// run_bgm_parallel
Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int warmup, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, uint64_t seed, int progress_type);
Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int warmup, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, int seed, int progress_type);
RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP warmupSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Expand Down Expand Up @@ -93,7 +93,7 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< bool >::type learn_mass_matrix(learn_mass_matrixSEXP);
Rcpp::traits::input_parameter< int >::type num_chains(num_chainsSEXP);
Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP);
Rcpp::traits::input_parameter< uint64_t >::type seed(seedSEXP);
Rcpp::traits::input_parameter< int >::type seed(seedSEXP);
Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP);
rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type));
return rcpp_result_gen;
Expand Down
6 changes: 3 additions & 3 deletions src/bgmCompare_parallel.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]]
#include <RcppParallel.h>
#include <RcppArmadillo.h>
#include "bgmCompare_sampler.h"
#include "rng_utils.h" // must be included before RcppParallel
#include <RcppParallel.h>
#include <tbb/global_control.h>
#include <vector>
#include <string>
#include "rng_utils.h"
#include "progress_manager.h"
#include "sampler_output.h"
#include "mcmc_adaptation.h"
Expand Down Expand Up @@ -219,7 +219,7 @@ struct GibbsCompareChainRunner : public Worker {

try {
// per-chain RNG
SafeRNG rng(chain_rngs[i]);
SafeRNG rng = chain_rngs[i];

// make per-chain copies
std::vector<arma::imat> counts_per_category = counts_per_category_master;
Expand Down
2 changes: 0 additions & 2 deletions src/bgmCompare_sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "mcmc_nuts.h"
#include "mcmc_rwm.h"
#include "mcmc_utils.h"
#include "print_mutex.h"
#include "rng_utils.h"
#include "sampler_output.h"
#include "explog_switch.h"
Expand Down Expand Up @@ -1645,7 +1644,6 @@ SamplerOutput run_gibbs_sampler_bgmCompare(
);

const int total_iter = warmup_schedule.total_warmup + iter;
const int print_every = std::max(1, total_iter / 10);

// --- Main Gibbs sampling loop
bool userInterrupt = false;
Expand Down
6 changes: 3 additions & 3 deletions src/bgm_parallel.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]]
#include <RcppParallel.h>
#include <RcppArmadillo.h>
#include "rng_utils.h" // must be included before RcppParallel
#include <RcppParallel.h>
#include "bgm_sampler.h"
#include <tbb/global_control.h>
#include <vector>
#include <string>
#include "rng_utils.h"
#include "progress_manager.h"
#include "mcmc_adaptation.h"

Expand Down Expand Up @@ -297,7 +297,7 @@ Rcpp::List run_bgm_parallel(
bool learn_mass_matrix,
int num_chains,
int nThreads,
uint64_t seed,
int seed,
int progress_type
) {
std::vector<ChainResult> results(num_chains);
Expand Down
1 change: 0 additions & 1 deletion src/bgm_sampler.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include <RcppArmadillo.h>
#include <Rcpp.h>
#include "bgm_helper.h"
#include "bgm_logp_and_grad.h"
#include "bgm_sampler.h"
Expand Down
1 change: 0 additions & 1 deletion src/mcmc_nuts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "mcmc_memoization.h"
#include "mcmc_nuts.h"
#include "mcmc_utils.h"
#include <Rcpp.h>
#include "rng_utils.h"
using namespace Rcpp;

Expand Down
6 changes: 5 additions & 1 deletion src/rng_utils.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// [[Rcpp::depends(BH)]]
#pragma once

// the order of these two is mandatory, RcppArmadillo must com before dqrng
#include <RcppArmadillo.h>
#include <dqrng.h>
#include <dqrng_generator.h>
Expand All @@ -9,6 +10,8 @@
#include <random>
#include <boost/random/beta_distribution.hpp>

// [[Rcpp::depends(dqrng, BH)]]

struct SafeRNG {
dqrng::xoshiro256plusplus eng;

Expand Down Expand Up @@ -107,4 +110,5 @@ inline arma::uvec arma_randperm(SafeRNG& rng, arma::uword n) {
std::iota(out.begin(), out.end(), 0);
std::shuffle(out.begin(), out.end(), rng.eng);
return out;
}
}

1 change: 0 additions & 1 deletion src/sbm_edge_prior.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include <RcppArmadillo.h>
#include <Rcpp.h>
#include "rng_utils.h"
#include "explog_switch.h"

Expand Down
Loading