Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions src/bgmCompare_parallel.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
// [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]]
#include <RcppParallel.h>
#include <RcppArmadillo.h>
#include <dqrng.h>
#include <dqrng_generator.h>
#include <xoshiro.h>
#include "bgmCompare_sampler.h"
#include <tbb/global_control.h>
#include <vector>
#include <string>
#include "rng_utils.h"

using namespace Rcpp;
using namespace RcppParallel;
Expand Down Expand Up @@ -59,7 +57,7 @@ struct GibbsCompareChainRunner : public Worker {
const arma::mat& inclusion_probability_master;

// RNG seeds
const std::vector<uint64_t>& chain_seeds;
const std::vector<SafeRNG>& chain_rngs;

// output
std::vector<ChainResult>& results;
Expand Down Expand Up @@ -95,7 +93,7 @@ struct GibbsCompareChainRunner : public Worker {
const arma::imat& group_indices,
const arma::imat& interaction_index_matrix,
const arma::mat& inclusion_probability_master,
const std::vector<uint64_t>& chain_seeds,
const std::vector<SafeRNG>& chain_rngs,
std::vector<ChainResult>& results
) :
observations(observations),
Expand Down Expand Up @@ -128,7 +126,7 @@ struct GibbsCompareChainRunner : public Worker {
group_indices(group_indices),
interaction_index_matrix(interaction_index_matrix),
inclusion_probability_master(inclusion_probability_master),
chain_seeds(chain_seeds),
chain_rngs(chain_rngs),
results(results)
{}

Expand All @@ -140,7 +138,7 @@ struct GibbsCompareChainRunner : public Worker {

try {
// per-chain RNG
dqrng::xoshiro256plus rng(chain_seeds[i]);
SafeRNG rng(chain_rngs[i]);

// make per-chain copies
std::vector<arma::imat> num_obs_categories_cpp = num_obs_categories_cpp_master;
Expand Down Expand Up @@ -247,12 +245,12 @@ Rcpp::List run_bgmCompare_parallel(
std::vector<ChainResult> results(num_chains);

// per-chain seeds
std::vector<uint64_t> chain_seeds(num_chains);
dqrng::xoshiro256plus seeder;
std::vector<SafeRNG> chain_rngs(num_chains);
for (int c = 0; c < num_chains; ++c) {
chain_seeds[c] = seeder();
chain_rngs[c] = SafeRNG(/*seed +*/ c); // TODO: this needs a seed passed by the user!
}


GibbsCompareChainRunner worker(
observations, num_groups,
num_obs_categories, sufficient_blume_capel, sufficient_pairwise,
Expand All @@ -262,7 +260,7 @@ Rcpp::List run_bgmCompare_parallel(
baseline_category, difference_selection, main_effect_indices,
pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix,
projection, group_membership, group_indices, interaction_index_matrix,
inclusion_probability, chain_seeds, results
inclusion_probability, chain_rngs, results
);

{
Expand Down
10 changes: 5 additions & 5 deletions src/bgmCompare_sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ List impute_missing_data_for_graphical_model(
const arma::imat& missing_data_indices,
const arma::uvec& is_ordinal_variable,
const arma::ivec& baseline_category,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
const int num_variables = observations.n_cols;
const int num_missings = missing_data_indices.n_rows;
Expand Down Expand Up @@ -234,7 +234,7 @@ double find_reasonable_initial_step_size(
const double main_alpha,
const double main_beta,
const double target_acceptance,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
arma::vec theta = vectorize_model_parameters(
main_effects, pairwise_effects, inclusion_indicator, main_effect_indices,
Expand Down Expand Up @@ -350,7 +350,7 @@ SamplerResult update_parameters_with_nuts(
HMCAdaptationController& adapt,
const bool learn_mass_matrix,
const bool selection,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
arma::vec current_state = vectorize_model_parameters(
main_effects, pairwise_effects, inclusion_indicator,
Expand Down Expand Up @@ -613,7 +613,7 @@ void gibbs_update_step_for_graphical_model_parameters (
const int num_groups,
const arma::imat group_indices,
double difference_scale,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {

SamplerResult result = update_parameters_with_nuts(
Expand Down Expand Up @@ -670,7 +670,7 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare(
const arma::imat& group_indices,//new
const arma::imat& interaction_index_matrix,//new
arma::mat inclusion_probability,//new
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
// --- Setup: dimensions and storage structures
const int num_variables = observations.n_cols;
Expand Down
6 changes: 3 additions & 3 deletions src/bgmCompare_sampler.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once
#include <dqrng.h>
#include <xoshiro.h>
#include <RcppArmadillo.h>

struct SafeRNG;

Rcpp::List run_gibbs_sampler_for_bgmCompare(
int chain_id,
arma::imat observations,
Expand Down Expand Up @@ -35,5 +35,5 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare(
const arma::imat& group_indices,//new
const arma::imat& interaction_index_matrix,//new
arma::mat inclusion_probability,
dqrng::xoshiro256plus& rng
SafeRNG& rng
);
2 changes: 1 addition & 1 deletion src/bgm_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ inline void initialise_graph(
const arma::mat& incl_prob,
arma::mat& rest,
const arma::imat& X,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
int V = indicator.n_rows;
for (int i = 0; i < V-1; ++i) {
Expand Down
22 changes: 4 additions & 18 deletions src/bgm_parallel.cpp
Original file line number Diff line number Diff line change
@@ -1,29 +1,18 @@
// [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]]
#include <RcppParallel.h>
#include <RcppArmadillo.h>
#include <dqrng.h>
#include <xoshiro.h>
#include "bgm_sampler.h"
#include <tbb/global_control.h>
#include <vector>
#include <string>
#include "rng_utils.h"

using namespace Rcpp;
using namespace RcppParallel;

// -----------------------------------------------------------------------------
// Wrapper to silence Clang warning
// -----------------------------------------------------------------------------
struct SafeRNG {
dqrng::xoshiro256plus eng;

SafeRNG() : eng() {}
SafeRNG(const dqrng::xoshiro256plus& other) : eng(other) {}
SafeRNG& operator=(const dqrng::xoshiro256plus& other) {
eng = other;
return *this;
}
};

// -----------------------------------------------------------------------------
// Result struct
Expand Down Expand Up @@ -146,7 +135,7 @@ struct GibbsChainRunner : public Worker {
out.error = false;

try {
dqrng::xoshiro256plus rng = chain_rngs[i].eng;
SafeRNG rng = chain_rngs[i];

Rcpp::List result = run_gibbs_sampler_for_bgm(
out.chain_id,
Expand Down Expand Up @@ -237,11 +226,8 @@ Rcpp::List run_bgm_parallel(

// Prepare one independent RNG per chain via jump()
std::vector<SafeRNG> chain_rngs(num_chains);
chain_rngs[0].eng = dqrng::xoshiro256plus(seed);

for (int c = 1; c < num_chains; ++c) {
chain_rngs[c].eng = chain_rngs[c-1].eng;
chain_rngs[c].eng.jump();
for (int c = 0; c < num_chains; ++c) {
chain_rngs[c] = SafeRNG(seed + c);
}

GibbsChainRunner worker(
Expand Down
22 changes: 11 additions & 11 deletions src/bgm_sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void impute_missing_values_for_graphical_model (
const arma::uvec& is_ordinal_variable,
const arma::ivec& reference_category,
arma::imat& sufficient_pairwise,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
const int num_variables = observations.n_cols;
const int num_missings = missing_index.n_rows;
Expand Down Expand Up @@ -193,7 +193,7 @@ double find_reasonable_initial_step_size(
const double interaction_scale,
const double target_acceptance,
const arma::imat& sufficient_pairwise,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
arma::vec theta = vectorize_model_parameters(
main_effects, pairwise_effects, inclusion_indicator,
Expand Down Expand Up @@ -278,7 +278,7 @@ void update_main_effects_with_metropolis (
arma::mat& proposal_sd_main,
RWMAdaptationController& adapter,
const int iteration,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
const int num_vars = observations.n_cols;
arma::umat index_mask_main = arma::ones<arma::umat>(proposal_sd_main.n_rows, proposal_sd_main.n_cols);
Expand Down Expand Up @@ -381,7 +381,7 @@ void update_pairwise_effects_with_metropolis (
const arma::ivec& reference_category,
const int iteration,
const arma::imat& sufficient_pairwise,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
arma::mat accept_prob_pairwise = arma::zeros<arma::mat>(num_variables, num_variables);
arma::umat index_mask_pairwise = arma::zeros<arma::umat>(num_variables, num_variables);
Expand Down Expand Up @@ -478,7 +478,7 @@ void update_parameters_with_hmc(
HMCAdaptationController& adapt,
const bool learn_mass_matrix,
const bool selection,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
arma::vec current_state = vectorize_model_parameters(
main_effects, pairwise_effects, inclusion_indicator,
Expand Down Expand Up @@ -587,7 +587,7 @@ SamplerResult update_parameters_with_nuts(
HMCAdaptationController& adapt,
const bool learn_mass_matrix,
const bool selection,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
arma::vec current_state = vectorize_model_parameters(
main_effects, pairwise_effects, inclusion_indicator,
Expand Down Expand Up @@ -668,7 +668,7 @@ void tune_pairwise_proposal_sd(
const arma::imat& sufficient_pairwise,
int iteration,
const WarmupSchedule& sched,
dqrng::xoshiro256plus& rng,
SafeRNG& rng,
double target_accept = 0.44,
double rm_decay = 0.75
)
Expand Down Expand Up @@ -762,7 +762,7 @@ void update_indicator_interaction_pair_with_metropolis (
const arma::uvec& is_ordinal_variable,
const arma::ivec& reference_category,
const arma::imat& sufficient_pairwise,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
for (int cntr = 0; cntr < num_interactions; cntr++) {
const int variable1 = index(cntr, 1);
Expand Down Expand Up @@ -912,7 +912,7 @@ void gibbs_update_step_for_graphical_model_parameters (
arma::ivec& treedepth_samples,
arma::ivec& divergent_samples,
arma::vec& energy_samples,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {

// Step 0: Initialise random graph structure when edge_selection = TRUE
Expand Down Expand Up @@ -1071,7 +1071,7 @@ Rcpp::List run_gibbs_sampler_for_bgm(
const int hmc_num_leapfrogs,
const int nuts_max_depth,
const bool learn_mass_matrix,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
// --- Setup: dimensions and storage structures
const int num_variables = observations.n_cols;
Expand Down Expand Up @@ -1176,8 +1176,8 @@ Rcpp::List run_gibbs_sampler_for_bgm(
for (int iteration = 0; iteration < total_iter; iteration++) {
if (iteration % print_every == 0) {
tbb::mutex::scoped_lock lock(get_print_mutex());
//Rcpp::Rcout
std::cout
// Rcpp::Rcout
<< "[bgm] chain " << chain_id
<< " iteration " << iteration
<< " / " << total_iter
Expand Down
7 changes: 4 additions & 3 deletions src/bgm_sampler.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#pragma once
#include <dqrng.h>
#include <xoshiro.h>
#include <RcppArmadillo.h>

// forward declaration
struct SafeRNG;

Rcpp::List run_gibbs_sampler_for_bgm(
int chain_id,
arma::imat observations,
Expand Down Expand Up @@ -33,5 +34,5 @@ Rcpp::List run_gibbs_sampler_for_bgm(
const int hmc_num_leapfrogs,
const int nuts_max_depth,
const bool learn_mass_matrix,
dqrng::xoshiro256plus& rng
SafeRNG& rng
);
8 changes: 4 additions & 4 deletions src/gibbs_functions_edge_prior.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ arma::uvec table_cpp(arma::uvec x) {
arma::mat add_row_col_block_prob_matrix(arma::mat X,
double beta_alpha,
double beta_beta,
dqrng::xoshiro256plus& rng) {
SafeRNG& rng) {
arma::uword dim = X.n_rows;
arma::mat Y(dim+1,dim+1,arma::fill::zeros);

Expand Down Expand Up @@ -159,7 +159,7 @@ inline void update_sumG(double &sumG,
// Sample the cluster assignment in sample_block_allocations_mfm_sbm()
// ----------------------------------------------------------------------------|
arma::uword sample_cluster(arma::vec cluster_prob,
dqrng::xoshiro256plus& rng) {
SafeRNG& rng) {
arma::vec cum_prob = arma::cumsum(cluster_prob);
double u = runif(rng) * arma::max(cum_prob);

Expand All @@ -182,7 +182,7 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign,
arma::uword dirichlet_alpha,
double beta_bernoulli_alpha,
double beta_bernoulli_beta,
dqrng::xoshiro256plus& rng) {
SafeRNG& rng) {
arma::uword old;
arma::uword cluster;
arma::uword no_clusters;
Expand Down Expand Up @@ -320,7 +320,7 @@ arma::mat block_probs_mfm_sbm(arma::uvec cluster_assign,
arma::uword no_variables,
double beta_bernoulli_alpha,
double beta_bernoulli_beta,
dqrng::xoshiro256plus& rng) {
SafeRNG& rng) {

arma::uvec cluster_size = table_cpp(cluster_assign);
arma::uword no_clusters = cluster_size.n_elem;
Expand Down
7 changes: 4 additions & 3 deletions src/gibbs_functions_edge_prior.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pragma once

#include "rng_utils.h"
#include <RcppArmadillo.h>
struct SafeRNG;


// ----------------------------------------------------------------------------|
// Compute partition coefficient for the MFM - SBM
Expand All @@ -22,7 +23,7 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign,
arma::uword dirichlet_alpha,
double beta_bernoulli_alpha,
double beta_bernoulli_beta,
dqrng::xoshiro256plus& rng);
SafeRNG& rng);

// ----------------------------------------------------------------------------|
// Sample the block parameters for the MFM - SBM
Expand All @@ -32,4 +33,4 @@ arma::mat block_probs_mfm_sbm(arma::uvec cluster_assign,
arma::uword no_variables,
double beta_bernoulli_alpha,
double beta_bernoulli_beta,
dqrng::xoshiro256plus& rng);
SafeRNG& rng);
2 changes: 1 addition & 1 deletion src/mcmc_hmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ SamplerResult hmc_sampler(
const std::function<arma::vec(const arma::vec&)>& grad,
const int num_leapfrogs,
const arma::vec& inv_mass_diag,
dqrng::xoshiro256plus& rng
SafeRNG& rng
) {
arma::vec theta = init_theta;
arma::vec init_r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem);
Expand Down
Loading