Skip to content

Commit 028ed9e

Browse files
authored
Merge branch 'adaMala' into adaMala_debug
2 parents 4144345 + d005c5c commit 028ed9e

19 files changed

+96
-102
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ URL: https://maartenmarsman.github.io/bgms/
2222
BugReports: https://github.com/MaartenMarsman/bgms/issues
2323
Imports: Rcpp (>= 1.0.7), RcppParallel, Rdpack, methods, coda, dqrng
2424
RdMacros: Rdpack
25-
LinkingTo: Rcpp, RcppProgress, RcppArmadillo, RcppParallel, dqrng
25+
LinkingTo: Rcpp, RcppProgress, RcppArmadillo, RcppParallel, dqrng, BH
2626
RoxygenNote: 7.3.2
2727
Depends:
2828
R (>= 2.10)

src/bgmCompare_parallel.cpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
// [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]]
22
#include <RcppParallel.h>
33
#include <RcppArmadillo.h>
4-
#include <dqrng.h>
5-
#include <dqrng_generator.h>
6-
#include <xoshiro.h>
4+
#include "bgmCompare_sampler.h"
75
#include <tbb/global_control.h>
86
#include <vector>
97
#include <string>
10-
#include "sampler_output.h"
11-
#include "bgmCompare_sampler.h"
8+
#include "rng_utils.h"
129

1310
using namespace Rcpp;
1411
using namespace RcppParallel;
@@ -60,7 +57,7 @@ struct GibbsCompareChainRunner : public Worker {
6057
const arma::mat& inclusion_probability_master;
6158

6259
// RNG seeds
63-
const std::vector<uint64_t>& chain_seeds;
60+
const std::vector<SafeRNG>& chain_rngs;
6461

6562
// output
6663
std::vector<ChainResultCompare>& results;
@@ -96,7 +93,7 @@ struct GibbsCompareChainRunner : public Worker {
9693
const arma::imat& group_indices,
9794
const arma::imat& interaction_index_matrix,
9895
const arma::mat& inclusion_probability_master,
99-
const std::vector<uint64_t>& chain_seeds,
96+
const std::vector<SafeRNG>& chain_rngs,
10097
std::vector<ChainResultCompare>& results
10198
) :
10299
observations(observations),
@@ -129,7 +126,7 @@ struct GibbsCompareChainRunner : public Worker {
129126
group_indices(group_indices),
130127
interaction_index_matrix(interaction_index_matrix),
131128
inclusion_probability_master(inclusion_probability_master),
132-
chain_seeds(chain_seeds),
129+
chain_rngs(chain_rngs),
133130
results(results)
134131
{}
135132

@@ -140,7 +137,8 @@ struct GibbsCompareChainRunner : public Worker {
140137
out.error = false;
141138

142139
try {
143-
dqrng::xoshiro256plus rng(chain_seeds[i]);
140+
// per-chain RNG
141+
SafeRNG rng(chain_rngs[i]);
144142

145143
// make per-chain copies
146144
std::vector<arma::imat> num_obs_categories = num_obs_categories_master;
@@ -242,12 +240,12 @@ Rcpp::List run_bgmCompare_parallel(
242240
std::vector<ChainResultCompare> results(num_chains);
243241

244242
// per-chain seeds
245-
std::vector<uint64_t> chain_seeds(num_chains);
246-
dqrng::xoshiro256plus seeder;
243+
std::vector<SafeRNG> chain_rngs(num_chains);
247244
for (int c = 0; c < num_chains; ++c) {
248-
chain_seeds[c] = seeder();
245+
chain_rngs[c] = SafeRNG(/*seed +*/ c); // TODO: this needs a seed passed by the user!
249246
}
250247

248+
251249
GibbsCompareChainRunner worker(
252250
observations, num_groups,
253251
num_obs_categories, sufficient_blume_capel, sufficient_pairwise,
@@ -257,7 +255,7 @@ Rcpp::List run_bgmCompare_parallel(
257255
baseline_category, difference_selection, main_effect_indices,
258256
pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix,
259257
projection, group_membership, group_indices, interaction_index_matrix,
260-
inclusion_probability, chain_seeds, results
258+
inclusion_probability, chain_rngs, results
261259
);
262260

263261
{

src/bgmCompare_sampler.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ List impute_missing_data_for_graphical_model(
6464
const arma::imat& missing_data_indices,
6565
const arma::uvec& is_ordinal_variable,
6666
const arma::ivec& baseline_category,
67-
dqrng::xoshiro256plus& rng
67+
SafeRNG& rng
6868
) {
6969
const int num_variables = observations.n_cols;
7070
const int num_missings = missing_data_indices.n_rows;
@@ -235,7 +235,7 @@ double find_reasonable_initial_step_size(
235235
const double main_alpha,
236236
const double main_beta,
237237
const double target_acceptance,
238-
dqrng::xoshiro256plus& rng
238+
SafeRNG& rng
239239
) {
240240
arma::vec theta = vectorize_model_parameters(
241241
main_effects, pairwise_effects, inclusion_indicator, main_effect_indices,
@@ -351,7 +351,7 @@ SamplerResult update_parameters_with_nuts(
351351
HMCAdaptationController& adapt,
352352
const bool learn_mass_matrix,
353353
const bool selection,
354-
dqrng::xoshiro256plus& rng
354+
SafeRNG& rng
355355
) {
356356
arma::vec current_state = vectorize_model_parameters(
357357
main_effects, pairwise_effects, inclusion_indicator,
@@ -614,7 +614,7 @@ void gibbs_update_step_for_graphical_model_parameters (
614614
const int num_groups,
615615
const arma::imat group_indices,
616616
double difference_scale,
617-
dqrng::xoshiro256plus& rng,
617+
SafeRNG& rng,
618618
arma::mat& inclusion_probability
619619
) {
620620

@@ -680,7 +680,7 @@ SamplerOutput run_gibbs_sampler_for_bgmCompare(
680680
const arma::imat& group_indices,//new
681681
const arma::imat& interaction_index_matrix,//new
682682
arma::mat inclusion_probability,//new
683-
dqrng::xoshiro256plus& rng
683+
SafeRNG& rng
684684
) {
685685
// --- Setup: dimensions and storage structures
686686
const int num_variables = observations.n_cols;

src/bgmCompare_sampler.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#pragma once
2-
#include <dqrng.h>
3-
#include <xoshiro.h>
42
#include <RcppArmadillo.h>
53
#include "sampler_output.h"
64

7-
SamplerOutput run_gibbs_sampler_for_bgmCompare(
5+
struct SafeRNG;
6+
7+
Rcpp::List run_gibbs_sampler_for_bgmCompare(
88
int chain_id,
99
arma::imat observations,
1010
const int num_groups,
@@ -36,5 +36,5 @@ SamplerOutput run_gibbs_sampler_for_bgmCompare(
3636
const arma::imat& group_indices,
3737
const arma::imat& interaction_index_matrix,
3838
arma::mat inclusion_probability,
39-
dqrng::xoshiro256plus& rng
39+
SafeRNG& rng
4040
);

src/bgm_helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ inline void initialise_graph(
5050
const arma::mat& incl_prob,
5151
arma::mat& rest,
5252
const arma::imat& X,
53-
dqrng::xoshiro256plus& rng
53+
SafeRNG& rng
5454
) {
5555
int V = indicator.n_rows;
5656
for (int i = 0; i < V-1; ++i) {

src/bgm_parallel.cpp

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,18 @@
11
// [[Rcpp::depends(RcppParallel, RcppArmadillo, dqrng)]]
22
#include <RcppParallel.h>
33
#include <RcppArmadillo.h>
4-
#include <dqrng.h>
5-
#include <xoshiro.h>
64
#include "bgm_sampler.h"
75
#include <tbb/global_control.h>
86
#include <vector>
97
#include <string>
8+
#include "rng_utils.h"
109

1110
using namespace Rcpp;
1211
using namespace RcppParallel;
1312

1413
// -----------------------------------------------------------------------------
1514
// Wrapper to silence Clang warning
1615
// -----------------------------------------------------------------------------
17-
struct SafeRNG {
18-
dqrng::xoshiro256plus eng;
19-
20-
SafeRNG() : eng() {}
21-
SafeRNG(const dqrng::xoshiro256plus& other) : eng(other) {}
22-
SafeRNG& operator=(const dqrng::xoshiro256plus& other) {
23-
eng = other;
24-
return *this;
25-
}
26-
};
2716

2817
// -----------------------------------------------------------------------------
2918
// Result struct
@@ -146,7 +135,7 @@ struct GibbsChainRunner : public Worker {
146135
out.error = false;
147136

148137
try {
149-
dqrng::xoshiro256plus rng = chain_rngs[i].eng;
138+
SafeRNG rng = chain_rngs[i];
150139

151140
Rcpp::List result = run_gibbs_sampler_for_bgm(
152141
out.chain_id,
@@ -237,11 +226,8 @@ Rcpp::List run_bgm_parallel(
237226

238227
// Prepare one independent RNG per chain via jump()
239228
std::vector<SafeRNG> chain_rngs(num_chains);
240-
chain_rngs[0].eng = dqrng::xoshiro256plus(seed);
241-
242-
for (int c = 1; c < num_chains; ++c) {
243-
chain_rngs[c].eng = chain_rngs[c-1].eng;
244-
chain_rngs[c].eng.jump();
229+
for (int c = 0; c < num_chains; ++c) {
230+
chain_rngs[c] = SafeRNG(seed + c);
245231
}
246232

247233
GibbsChainRunner worker(

src/bgm_sampler.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ void impute_missing_values_for_graphical_model (
5050
const arma::uvec& is_ordinal_variable,
5151
const arma::ivec& reference_category,
5252
arma::imat& sufficient_pairwise,
53-
dqrng::xoshiro256plus& rng
53+
SafeRNG& rng
5454
) {
5555
const int num_variables = observations.n_cols;
5656
const int num_missings = missing_index.n_rows;
@@ -193,7 +193,7 @@ double find_reasonable_initial_step_size(
193193
const double interaction_scale,
194194
const double target_acceptance,
195195
const arma::imat& sufficient_pairwise,
196-
dqrng::xoshiro256plus& rng
196+
SafeRNG& rng
197197
) {
198198
arma::vec theta = vectorize_model_parameters(
199199
main_effects, pairwise_effects, inclusion_indicator,
@@ -278,7 +278,7 @@ void update_main_effects_with_metropolis (
278278
arma::mat& proposal_sd_main,
279279
RWMAdaptationController& adapter,
280280
const int iteration,
281-
dqrng::xoshiro256plus& rng
281+
SafeRNG& rng
282282
) {
283283
const int num_vars = observations.n_cols;
284284
arma::umat index_mask_main = arma::ones<arma::umat>(proposal_sd_main.n_rows, proposal_sd_main.n_cols);
@@ -381,7 +381,7 @@ void update_pairwise_effects_with_metropolis (
381381
const arma::ivec& reference_category,
382382
const int iteration,
383383
const arma::imat& sufficient_pairwise,
384-
dqrng::xoshiro256plus& rng
384+
SafeRNG& rng
385385
) {
386386
arma::mat accept_prob_pairwise = arma::zeros<arma::mat>(num_variables, num_variables);
387387
arma::umat index_mask_pairwise = arma::zeros<arma::umat>(num_variables, num_variables);
@@ -478,7 +478,7 @@ void update_parameters_with_hmc(
478478
HMCAdaptationController& adapt,
479479
const bool learn_mass_matrix,
480480
const bool selection,
481-
dqrng::xoshiro256plus& rng
481+
SafeRNG& rng
482482
) {
483483
arma::vec current_state = vectorize_model_parameters(
484484
main_effects, pairwise_effects, inclusion_indicator,
@@ -587,7 +587,7 @@ SamplerResult update_parameters_with_nuts(
587587
HMCAdaptationController& adapt,
588588
const bool learn_mass_matrix,
589589
const bool selection,
590-
dqrng::xoshiro256plus& rng
590+
SafeRNG& rng
591591
) {
592592
arma::vec current_state = vectorize_model_parameters(
593593
main_effects, pairwise_effects, inclusion_indicator,
@@ -668,7 +668,7 @@ void tune_pairwise_proposal_sd(
668668
const arma::imat& sufficient_pairwise,
669669
int iteration,
670670
const WarmupSchedule& sched,
671-
dqrng::xoshiro256plus& rng,
671+
SafeRNG& rng,
672672
double target_accept = 0.44,
673673
double rm_decay = 0.75
674674
)
@@ -762,7 +762,7 @@ void update_indicator_interaction_pair_with_metropolis (
762762
const arma::uvec& is_ordinal_variable,
763763
const arma::ivec& reference_category,
764764
const arma::imat& sufficient_pairwise,
765-
dqrng::xoshiro256plus& rng
765+
SafeRNG& rng
766766
) {
767767
for (int cntr = 0; cntr < num_interactions; cntr++) {
768768
const int variable1 = index(cntr, 1);
@@ -912,7 +912,7 @@ void gibbs_update_step_for_graphical_model_parameters (
912912
arma::ivec& treedepth_samples,
913913
arma::ivec& divergent_samples,
914914
arma::vec& energy_samples,
915-
dqrng::xoshiro256plus& rng
915+
SafeRNG& rng
916916
) {
917917

918918
// Step 0: Initialise random graph structure when edge_selection = TRUE
@@ -1071,7 +1071,7 @@ Rcpp::List run_gibbs_sampler_for_bgm(
10711071
const int hmc_num_leapfrogs,
10721072
const int nuts_max_depth,
10731073
const bool learn_mass_matrix,
1074-
dqrng::xoshiro256plus& rng
1074+
SafeRNG& rng
10751075
) {
10761076
// --- Setup: dimensions and storage structures
10771077
const int num_variables = observations.n_cols;
@@ -1176,8 +1176,8 @@ Rcpp::List run_gibbs_sampler_for_bgm(
11761176
for (int iteration = 0; iteration < total_iter; iteration++) {
11771177
if (iteration % print_every == 0) {
11781178
tbb::mutex::scoped_lock lock(get_print_mutex());
1179-
//Rcpp::Rcout
11801179
std::cout
1180+
// Rcpp::Rcout
11811181
<< "[bgm] chain " << chain_id
11821182
<< " iteration " << iteration
11831183
<< " / " << total_iter

src/bgm_sampler.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#pragma once
2-
#include <dqrng.h>
3-
#include <xoshiro.h>
42
#include <RcppArmadillo.h>
53

4+
// forward declaration
5+
struct SafeRNG;
6+
67
Rcpp::List run_gibbs_sampler_for_bgm(
78
int chain_id,
89
arma::imat observations,
@@ -33,5 +34,5 @@ Rcpp::List run_gibbs_sampler_for_bgm(
3334
const int hmc_num_leapfrogs,
3435
const int nuts_max_depth,
3536
const bool learn_mass_matrix,
36-
dqrng::xoshiro256plus& rng
37+
SafeRNG& rng
3738
);

src/gibbs_functions_edge_prior.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ arma::uvec table_cpp(arma::uvec x) {
3333
arma::mat add_row_col_block_prob_matrix(arma::mat X,
3434
double beta_alpha,
3535
double beta_beta,
36-
dqrng::xoshiro256plus& rng) {
36+
SafeRNG& rng) {
3737
arma::uword dim = X.n_rows;
3838
arma::mat Y(dim+1,dim+1,arma::fill::zeros);
3939

@@ -159,7 +159,7 @@ inline void update_sumG(double &sumG,
159159
// Sample the cluster assignment in sample_block_allocations_mfm_sbm()
160160
// ----------------------------------------------------------------------------|
161161
arma::uword sample_cluster(arma::vec cluster_prob,
162-
dqrng::xoshiro256plus& rng) {
162+
SafeRNG& rng) {
163163
arma::vec cum_prob = arma::cumsum(cluster_prob);
164164
double u = runif(rng) * arma::max(cum_prob);
165165

@@ -182,7 +182,7 @@ arma::uvec block_allocations_mfm_sbm(arma::uvec cluster_assign,
182182
arma::uword dirichlet_alpha,
183183
double beta_bernoulli_alpha,
184184
double beta_bernoulli_beta,
185-
dqrng::xoshiro256plus& rng) {
185+
SafeRNG& rng) {
186186
arma::uword old;
187187
arma::uword cluster;
188188
arma::uword no_clusters;
@@ -320,7 +320,7 @@ arma::mat block_probs_mfm_sbm(arma::uvec cluster_assign,
320320
arma::uword no_variables,
321321
double beta_bernoulli_alpha,
322322
double beta_bernoulli_beta,
323-
dqrng::xoshiro256plus& rng) {
323+
SafeRNG& rng) {
324324

325325
arma::uvec cluster_size = table_cpp(cluster_assign);
326326
arma::uword no_clusters = cluster_size.n_elem;

0 commit comments

Comments
 (0)