Skip to content

Commit f1f8c78

Browse files
authored
fix parallel rng (#52)
1 parent 069ec48 commit f1f8c78

18 files changed

+91
-98
lines changed

src/bgmCompare_parallel.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
// [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]]
22
#include <RcppParallel.h>
33
#include <RcppArmadillo.h>
4-
#include <dqrng.h>
5-
#include <dqrng_generator.h>
6-
#include <xoshiro.h>
74
#include "bgmCompare_sampler.h"
85
#include <tbb/global_control.h>
96
#include <vector>
107
#include <string>
8+
#include "rng_utils.h"
119

1210
using namespace Rcpp;
1311
using namespace RcppParallel;
@@ -59,7 +57,7 @@ struct GibbsCompareChainRunner : public Worker {
5957
const arma::mat& inclusion_probability_master;
6058

6159
// RNG seeds
62-
const std::vector<uint64_t>& chain_seeds;
60+
const std::vector<SafeRNG>& chain_rngs;
6361

6462
// output
6563
std::vector<ChainResult>& results;
@@ -95,7 +93,7 @@ struct GibbsCompareChainRunner : public Worker {
9593
const arma::imat& group_indices,
9694
const arma::imat& interaction_index_matrix,
9795
const arma::mat& inclusion_probability_master,
98-
const std::vector<uint64_t>& chain_seeds,
96+
const std::vector<SafeRNG>& chain_rngs,
9997
std::vector<ChainResult>& results
10098
) :
10199
observations(observations),
@@ -128,7 +126,7 @@ struct GibbsCompareChainRunner : public Worker {
128126
group_indices(group_indices),
129127
interaction_index_matrix(interaction_index_matrix),
130128
inclusion_probability_master(inclusion_probability_master),
131-
chain_seeds(chain_seeds),
129+
chain_rngs(chain_rngs),
132130
results(results)
133131
{}
134132

@@ -140,7 +138,7 @@ struct GibbsCompareChainRunner : public Worker {
140138

141139
try {
142140
// per-chain RNG
143-
dqrng::xoshiro256plus rng(chain_seeds[i]);
141+
SafeRNG rng(chain_rngs[i]);
144142

145143
// make per-chain copies
146144
std::vector<arma::imat> num_obs_categories_cpp = num_obs_categories_cpp_master;
@@ -247,12 +245,12 @@ Rcpp::List run_bgmCompare_parallel(
247245
std::vector<ChainResult> results(num_chains);
248246

249247
// per-chain seeds
250-
std::vector<uint64_t> chain_seeds(num_chains);
251-
dqrng::xoshiro256plus seeder;
248+
std::vector<SafeRNG> chain_rngs(num_chains);
252249
for (int c = 0; c < num_chains; ++c) {
253-
chain_seeds[c] = seeder();
250+
chain_rngs[c] = SafeRNG(/*seed +*/ c); // TODO: this needs a seed passed by the user!
254251
}
255252

253+
256254
GibbsCompareChainRunner worker(
257255
observations, num_groups,
258256
num_obs_categories, sufficient_blume_capel, sufficient_pairwise,
@@ -262,7 +260,7 @@ Rcpp::List run_bgmCompare_parallel(
262260
baseline_category, difference_selection, main_effect_indices,
263261
pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix,
264262
projection, group_membership, group_indices, interaction_index_matrix,
265-
inclusion_probability, chain_seeds, results
263+
inclusion_probability, chain_rngs, results
266264
);
267265

268266
{

src/bgmCompare_sampler.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ List impute_missing_data_for_graphical_model(
6363
const arma::imat& missing_data_indices,
6464
const arma::uvec& is_ordinal_variable,
6565
const arma::ivec& baseline_category,
66-
dqrng::xoshiro256plus& rng
66+
SafeRNG& rng
6767
) {
6868
const int num_variables = observations.n_cols;
6969
const int num_missings = missing_data_indices.n_rows;
@@ -234,7 +234,7 @@ double find_reasonable_initial_step_size(
234234
const double main_alpha,
235235
const double main_beta,
236236
const double target_acceptance,
237-
dqrng::xoshiro256plus& rng
237+
SafeRNG& rng
238238
) {
239239
arma::vec theta = vectorize_model_parameters(
240240
main_effects, pairwise_effects, inclusion_indicator, main_effect_indices,
@@ -350,7 +350,7 @@ SamplerResult update_parameters_with_nuts(
350350
HMCAdaptationController& adapt,
351351
const bool learn_mass_matrix,
352352
const bool selection,
353-
dqrng::xoshiro256plus& rng
353+
SafeRNG& rng
354354
) {
355355
arma::vec current_state = vectorize_model_parameters(
356356
main_effects, pairwise_effects, inclusion_indicator,
@@ -613,7 +613,7 @@ void gibbs_update_step_for_graphical_model_parameters (
613613
const int num_groups,
614614
const arma::imat group_indices,
615615
double difference_scale,
616-
dqrng::xoshiro256plus& rng
616+
SafeRNG& rng
617617
) {
618618

619619
SamplerResult result = update_parameters_with_nuts(
@@ -670,7 +670,7 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare(
670670
const arma::imat& group_indices,//new
671671
const arma::imat& interaction_index_matrix,//new
672672
arma::mat inclusion_probability,//new
673-
dqrng::xoshiro256plus& rng
673+
SafeRNG& rng
674674
) {
675675
// --- Setup: dimensions and storage structures
676676
const int num_variables = observations.n_cols;

src/bgmCompare_sampler.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#pragma once
2-
#include <dqrng.h>
3-
#include <xoshiro.h>
42
#include <RcppArmadillo.h>
53

4+
struct SafeRNG;
5+
66
Rcpp::List run_gibbs_sampler_for_bgmCompare(
77
int chain_id,
88
arma::imat observations,
@@ -35,5 +35,5 @@ Rcpp::List run_gibbs_sampler_for_bgmCompare(
3535
const arma::imat& group_indices,//new
3636
const arma::imat& interaction_index_matrix,//new
3737
arma::mat inclusion_probability,
38-
dqrng::xoshiro256plus& rng
38+
SafeRNG& rng
3939
);

src/bgm_helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ inline void initialise_graph(
5050
const arma::mat& incl_prob,
5151
arma::mat& rest,
5252
const arma::imat& X,
53-
dqrng::xoshiro256plus& rng
53+
SafeRNG& rng
5454
) {
5555
int V = indicator.n_rows;
5656
for (int i = 0; i < V-1; ++i) {

src/bgm_parallel.cpp

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,18 @@
11
// [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]]
22
#include <RcppParallel.h>
33
#include <RcppArmadillo.h>
4-
#include <dqrng.h>
5-
#include <xoshiro.h>
64
#include "bgm_sampler.h"
75
#include <tbb/global_control.h>
86
#include <vector>
97
#include <string>
8+
#include "rng_utils.h"
109

1110
using namespace Rcpp;
1211
using namespace RcppParallel;
1312

1413
// -----------------------------------------------------------------------------
1514
// Wrapper to silence Clang warning
1615
// -----------------------------------------------------------------------------
17-
struct SafeRNG {
18-
dqrng::xoshiro256plus eng;
19-
20-
SafeRNG() : eng() {}
21-
SafeRNG(const dqrng::xoshiro256plus& other) : eng(other) {}
22-
SafeRNG& operator=(const dqrng::xoshiro256plus& other) {
23-
eng = other;
24-
return *this;
25-
}
26-
};
2716

2817
// -----------------------------------------------------------------------------
2918
// Result struct
@@ -146,7 +135,7 @@ struct GibbsChainRunner : public Worker {
146135
out.error = false;
147136

148137
try {
149-
dqrng::xoshiro256plus rng = chain_rngs[i].eng;
138+
SafeRNG rng = chain_rngs[i];
150139

151140
Rcpp::List result = run_gibbs_sampler_for_bgm(
152141
out.chain_id,
@@ -237,11 +226,8 @@ Rcpp::List run_bgm_parallel(
237226

238227
// Prepare one independent RNG per chain via jump()
239228
std::vector<SafeRNG> chain_rngs(num_chains);
240-
chain_rngs[0].eng = dqrng::xoshiro256plus(seed);
241-
242-
for (int c = 1; c < num_chains; ++c) {
243-
chain_rngs[c].eng = chain_rngs[c-1].eng;
244-
chain_rngs[c].eng.jump();
229+
for (int c = 0; c < num_chains; ++c) {
230+
chain_rngs[c] = SafeRNG(seed + c);
245231
}
246232

247233
GibbsChainRunner worker(

src/bgm_sampler.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ void impute_missing_values_for_graphical_model (
5050
const arma::uvec& is_ordinal_variable,
5151
const arma::ivec& reference_category,
5252
arma::imat& sufficient_pairwise,
53-
dqrng::xoshiro256plus& rng
53+
SafeRNG& rng
5454
) {
5555
const int num_variables = observations.n_cols;
5656
const int num_missings = missing_index.n_rows;
@@ -193,7 +193,7 @@ double find_reasonable_initial_step_size(
193193
const double interaction_scale,
194194
const double target_acceptance,
195195
const arma::imat& sufficient_pairwise,
196-
dqrng::xoshiro256plus& rng
196+
SafeRNG& rng
197197
) {
198198
arma::vec theta = vectorize_model_parameters(
199199
main_effects, pairwise_effects, inclusion_indicator,
@@ -278,7 +278,7 @@ void update_main_effects_with_metropolis (
278278
arma::mat& proposal_sd_main,
279279
RWMAdaptationController& adapter,
280280
const int iteration,
281-
dqrng::xoshiro256plus& rng
281+
SafeRNG& rng
282282
) {
283283
const int num_vars = observations.n_cols;
284284
arma::umat index_mask_main = arma::ones<arma::umat>(proposal_sd_main.n_rows, proposal_sd_main.n_cols);
@@ -381,7 +381,7 @@ void update_pairwise_effects_with_metropolis (
381381
const arma::ivec& reference_category,
382382
const int iteration,
383383
const arma::imat& sufficient_pairwise,
384-
dqrng::xoshiro256plus& rng
384+
SafeRNG& rng
385385
) {
386386
arma::mat accept_prob_pairwise = arma::zeros<arma::mat>(num_variables, num_variables);
387387
arma::umat index_mask_pairwise = arma::zeros<arma::umat>(num_variables, num_variables);
@@ -478,7 +478,7 @@ void update_parameters_with_hmc(
478478
HMCAdaptationController& adapt,
479479
const bool learn_mass_matrix,
480480
const bool selection,
481-
dqrng::xoshiro256plus& rng
481+
SafeRNG& rng
482482
) {
483483
arma::vec current_state = vectorize_model_parameters(
484484
main_effects, pairwise_effects, inclusion_indicator,
@@ -587,7 +587,7 @@ SamplerResult update_parameters_with_nuts(
587587
HMCAdaptationController& adapt,
588588
const bool learn_mass_matrix,
589589
const bool selection,
590-
dqrng::xoshiro256plus& rng
590+
SafeRNG& rng
591591
) {
592592
arma::vec current_state = vectorize_model_parameters(
593593
main_effects, pairwise_effects, inclusion_indicator,
@@ -668,7 +668,7 @@ void tune_pairwise_proposal_sd(
668668
const arma::imat& sufficient_pairwise,
669669
int iteration,
670670
const WarmupSchedule& sched,
671-
dqrng::xoshiro256plus& rng,
671+
SafeRNG& rng,
672672
double target_accept = 0.44,
673673
double rm_decay = 0.75
674674
)
@@ -762,7 +762,7 @@ void update_indicator_interaction_pair_with_metropolis (
762762
const arma::uvec& is_ordinal_variable,
763763
const arma::ivec& reference_category,
764764
const arma::imat& sufficient_pairwise,
765-
dqrng::xoshiro256plus& rng
765+
SafeRNG& rng
766766
) {
767767
for (int cntr = 0; cntr < num_interactions; cntr++) {
768768
const int variable1 = index(cntr, 1);
@@ -912,7 +912,7 @@ void gibbs_update_step_for_graphical_model_parameters (
912912
arma::ivec& treedepth_samples,
913913
arma::ivec& divergent_samples,
914914
arma::vec& energy_samples,
915-
dqrng::xoshiro256plus& rng
915+
SafeRNG& rng
916916
) {
917917

918918
// Step 0: Initialise random graph structure when edge_selection = TRUE
@@ -1071,7 +1071,7 @@ Rcpp::List run_gibbs_sampler_for_bgm(
10711071
const int hmc_num_leapfrogs,
10721072
const int nuts_max_depth,
10731073
const bool learn_mass_matrix,
1074-
dqrng::xoshiro256plus& rng
1074+
SafeRNG& rng
10751075
) {
10761076
// --- Setup: dimensions and storage structures
10771077
const int num_variables = observations.n_cols;
@@ -1176,8 +1176,8 @@ Rcpp::List run_gibbs_sampler_for_bgm(
11761176
for (int iteration = 0; iteration < total_iter; iteration++) {
11771177
if (iteration % print_every == 0) {
11781178
tbb::mutex::scoped_lock lock(get_print_mutex());
1179-
//Rcpp::Rcout
11801179
std::cout
1180+
// Rcpp::Rcout
11811181
<< "[bgm] chain " << chain_id
11821182
<< " iteration " << iteration
11831183
<< " / " << total_iter

src/bgm_sampler.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#pragma once
2-
#include <dqrng.h>
3-
#include <xoshiro.h>
42
#include <RcppArmadillo.h>
53

4+
// forward declaration
5+
struct SafeRNG;
6+
67
Rcpp::List run_gibbs_sampler_for_bgm(
78
int chain_id,
89
arma::imat observations,
@@ -33,5 +34,5 @@ Rcpp::List run_gibbs_sampler_for_bgm(
3334
const int hmc_num_leapfrogs,
3435
const int nuts_max_depth,
3536
const bool learn_mass_matrix,
36-
dqrng::xoshiro256plus& rng
37+
SafeRNG& rng
3738
);

src/gibbs_functions_edge_prior.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ arma::uvec table_cpp(arma::uvec x) {
3333
arma::mat add_row_col_block_prob_matrix(arma::mat X,
3434
double beta_alpha,
3535
double beta_beta,
36-
dqrng::xoshiro256plus& rng) {
36+
SafeRNG& rng) {
3737
arma::uword dim = X.n_rows;
3838
arma::mat Y(dim+1,dim+1,arma::fill::zeros);
3939

@@ -159,7 +159,7 @@ inline void update_sumG(double &sumG,
159159
// Sample the cluster assignment in sample_block_allocations_mfm_sbm()
160160
// ----------------------------------------------------------------------------|
161161
arma::uword sample_cluster(arma::vec cluster_prob,
162-
dqrng::xoshiro256plus& rng) {
162+
SafeRNG& rng) {
163163
arma::vec cum_prob = arma::cumsum(cluster_prob);
164164
double u = runif(rng) * arma::max(cum_prob);
165165

@@ -182,7 +182,7 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign,
182182
arma::uword dirichlet_alpha,
183183
double beta_bernoulli_alpha,
184184
double beta_bernoulli_beta,
185-
dqrng::xoshiro256plus& rng) {
185+
SafeRNG& rng) {
186186
arma::uword old;
187187
arma::uword cluster;
188188
arma::uword no_clusters;
@@ -320,7 +320,7 @@ arma::mat block_probs_mfm_sbm(arma::uvec cluster_assign,
320320
arma::uword no_variables,
321321
double beta_bernoulli_alpha,
322322
double beta_bernoulli_beta,
323-
dqrng::xoshiro256plus& rng) {
323+
SafeRNG& rng) {
324324

325325
arma::uvec cluster_size = table_cpp(cluster_assign);
326326
arma::uword no_clusters = cluster_size.n_elem;

src/gibbs_functions_edge_prior.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#pragma once
22

3-
#include "rng_utils.h"
43
#include <RcppArmadillo.h>
4+
struct SafeRNG;
5+
56

67
// ----------------------------------------------------------------------------|
78
// Compute partition coefficient for the MFM - SBM
@@ -22,7 +23,7 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign,
2223
arma::uword dirichlet_alpha,
2324
double beta_bernoulli_alpha,
2425
double beta_bernoulli_beta,
25-
dqrng::xoshiro256plus& rng);
26+
SafeRNG& rng);
2627

2728
// ----------------------------------------------------------------------------|
2829
// Sample the block parameters for the MFM - SBM
@@ -32,4 +33,4 @@ arma::mat block_probs_mfm_sbm(arma::uvec cluster_assign,
3233
arma::uword no_variables,
3334
double beta_bernoulli_alpha,
3435
double beta_bernoulli_beta,
35-
dqrng::xoshiro256plus& rng);
36+
SafeRNG& rng);

src/mcmc_hmc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ SamplerResult hmc_sampler(
1515
const std::function<arma::vec(const arma::vec&)>& grad,
1616
const int num_leapfrogs,
1717
const arma::vec& inv_mass_diag,
18-
dqrng::xoshiro256plus& rng
18+
SafeRNG& rng
1919
) {
2020
arma::vec theta = init_theta;
2121
arma::vec init_r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem);

0 commit comments

Comments
 (0)