Skip to content

Commit 921297e

Browse files
Init dqrng
1 parent 1c720e2 commit 921297e

25 files changed

+654
-469
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ Description: Bayesian variable selection methods for analyzing the structure of
2020
License: GPL (>= 2)
2121
URL: https://maartenmarsman.github.io/bgms/
2222
BugReports: https://github.com/MaartenMarsman/bgms/issues
23-
Imports: Rcpp (>= 1.0.7), RcppParallel, Rdpack, methods, coda
23+
Imports: Rcpp (>= 1.0.7), RcppParallel, Rdpack, methods, coda, dqrng
2424
RdMacros: Rdpack
25-
LinkingTo: Rcpp, RcppProgress, RcppArmadillo, RcppParallel
25+
LinkingTo: Rcpp, RcppProgress, RcppArmadillo, RcppParallel, dqrng
2626
RoxygenNote: 7.3.2
2727
Depends:
2828
R (>= 2.10)

R/RcppExports.R

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,8 @@ run_bgmCompare_parallel <- function(observations, num_groups, num_obs_categories
55
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads)
66
}
77

8-
run_gibbs_sampler_for_bgmCompare <- function(chain_id, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability) {
9-
.Call(`_bgms_run_gibbs_sampler_for_bgmCompare`, chain_id, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability)
10-
}
11-
12-
run_bgm_parallel <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads) {
13-
.Call(`_bgms_run_bgm_parallel`, observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads)
14-
}
15-
16-
run_gibbs_sampler_for_bgm <- function(chain_id, observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix) {
17-
.Call(`_bgms_run_gibbs_sampler_for_bgm`, chain_id, observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix)
8+
run_bgm_parallel <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) {
9+
.Call(`_bgms_run_bgm_parallel`, observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed)
1810
}
1911

2012
sample_omrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, iter) {

R/bgm.R

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,8 @@ bgm = function(
336336
nuts_max_depth = 10,
337337
learn_mass_matrix = FALSE,
338338
chains = 4,
339-
cores = parallel::detectCores()
339+
cores = parallel::detectCores(),
340+
seed = NULL
340341
) {
341342
# Check update method
342343
update_method_input = update_method
@@ -488,6 +489,18 @@ bgm = function(
488489
}
489490
}
490491

492+
#Setting the seed
493+
if (missing(seed) || is.null(seed)) {
494+
# Draw a random seed if none provided
495+
seed <- sample.int(.Machine$integer.max, 1)
496+
}
497+
498+
if (!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) {
499+
stop("Argument 'seed' must be a single non-negative integer.")
500+
}
501+
502+
seed <- as.integer(seed)
503+
491504
out = run_bgm_parallel(
492505
observations = x, num_categories = num_categories,
493506
interaction_scale = interaction_scale, edge_prior = edge_prior,
@@ -507,7 +520,7 @@ bgm = function(
507520
target_accept = target_accept, sufficient_pairwise = sufficient_pairwise,
508521
hmc_num_leapfrogs = hmc_num_leapfrogs, nuts_max_depth = nuts_max_depth,
509522
learn_mass_matrix = learn_mass_matrix, num_chains = chains,
510-
nThreads = cores
523+
nThreads = cores, seed = seed
511524
)
512525

513526
# Main output handler in the wrapper function

R/bgmCompare2.R

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ bgmCompare2 = function(
2424
nuts_max_depth = 10,
2525
learn_mass_matrix = FALSE,
2626
chains = 4,
27-
cores = parallel::detectCores()
27+
cores = parallel::detectCores(),
28+
seed = NULL
2829
) {
2930
# Check update method
3031
update_method_input = update_method
@@ -190,6 +191,15 @@ bgmCompare2 = function(
190191
projection = matrix(projection, ncol = 1) / sqrt(2)
191192
}
192193

194+
if (!is.null(seed)) {
195+
if (!is.numeric(seed) || any(is.na(seed)) || any(seed < 0)) {
196+
stop("Argument 'seed' must be a non-negative integer or vector of non-negative integers.")
197+
}
198+
# Force to integer type
199+
seed <- as.integer(seed)
200+
dqrng::dqset.seed(seed)
201+
}
202+
193203
# Call the Rcpp function
194204
out = run_bgmCompare_parallel(
195205
observations = observations,

src/Makevars

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
CXX_STD = CXX11
1+
CXX_STD = CXX14
22

33
## Pull in the include-paths for RcppParallel
44
PKG_CPPFLAGS = \

src/Makevars.win

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
CXX_STD = CXX11
1+
CXX_STD = CXX14
22

33
PKG_CPPFLAGS = \
44
$(shell "$(R_HOME)\bin\Rscript.exe" -e "cat(RcppParallel::CxxFlags())")

0 commit comments

Comments
 (0)