Skip to content

Commit bac0991

Browse files
authored
parallel progress bar and interrupt (#56)
* parallel progress bar and interrupt * let's rebase first * fix rebase * almost done * cleanup * fix rebase again
1 parent e05a3ff commit bac0991

15 files changed

+641
-48
lines changed

R/RcppExports.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
22
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
33

4-
run_bgmCompare_parallel <- function(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs) {
5-
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs)
4+
run_bgmCompare_parallel <- function(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type) {
5+
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type)
66
}
77

8-
run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) {
9-
.Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed)
8+
run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) {
9+
.Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)
1010
}
1111

1212
get_explog_switch <- function() {

R/bgm.R

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ bgm = function(
237237
dirichlet_alpha = 1,
238238
lambda = 1,
239239
na_action = c("listwise", "impute"),
240-
display_progress = TRUE,
240+
display_progress = c("per-chain", "total", "none"),
241241
update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"),
242242
target_accept,
243243
hmc_num_leapfrogs = 100,
@@ -324,7 +324,7 @@ bgm = function(
324324
"."))
325325

326326
#Check display_progress ------------------------------------------------------
327-
display_progress = check_logical(display_progress, "display_progress")
327+
progress_type = progress_type_from_display_progress(display_progress)
328328

329329
#Format the data input -------------------------------------------------------
330330
data = reformat_data(x = x,
@@ -428,7 +428,7 @@ bgm = function(
428428
target_accept = target_accept, pairwise_stats = pairwise_stats,
429429
hmc_num_leapfrogs = hmc_num_leapfrogs, nuts_max_depth = nuts_max_depth,
430430
learn_mass_matrix = learn_mass_matrix, num_chains = chains,
431-
nThreads = cores, seed = seed
431+
nThreads = cores, seed = seed, progress_type = progress_type
432432
)
433433

434434
# Main output handler in the wrapper function
@@ -457,5 +457,9 @@ bgm = function(
457457
output$nuts_diag = nuts_diag
458458
}
459459

460+
userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
461+
if (userInterrupt)
462+
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")
463+
460464
return(output)
461465
}

R/bgmCompare.R

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ bgmCompare = function(
122122
iter = 1e3,
123123
burnin = 1e3,
124124
na_action = c("listwise", "impute"),
125-
display_progress = TRUE,
125+
display_progress = c("per-chain", "total", "none"),
126126
update_method = c("nuts", "adaptive-metropolis", "hamiltonian-mc"),
127127
target_accept,
128128
hmc_num_leapfrogs = 100,
@@ -202,7 +202,8 @@ bgmCompare = function(
202202
}
203203

204204
# Check display_progress
205-
display_progress = check_logical(display_progress, "display_progress")
205+
progress_type = progress_type_from_display_progress(display_progress)
206+
206207

207208
## Format data
208209
data = compare_reformat_data(
@@ -340,7 +341,8 @@ bgmCompare = function(
340341
inclusion_probability = model$inclusion_probability_difference,
341342
num_chains = chains, nThreads = cores,
342343
seed = seed,
343-
update_method = update_method, hmc_num_leapfrogs = hmc_num_leapfrogs
344+
update_method = update_method, hmc_num_leapfrogs = hmc_num_leapfrogs,
345+
progress_type = progress_type
344346
)
345347

346348
# Main output handler in the wrapper function

R/function_input_utils.R

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,4 +463,15 @@ check_compare_model = function(
463463
inclusion_probability_difference = inclusion_probability_difference
464464
)
465465
)
466-
}
466+
}
467+
468+
progress_type_from_display_progress <- function(display_progress = c("per-chain", "total", "none")) {
469+
if (is.logical(display_progress) && length(display_progress) == 1) {
470+
if (is.na(display_progress))
471+
stop("The display_progress argument must be a single logical value, but not NA.")
472+
display_progress = if (display_progress) "per-chain" else "none"
473+
} else {
474+
display_progress = match.arg(display_progress)
475+
}
476+
return(if (display_progress == "per-chain") 2L else if (display_progress == "total") 1L else 0L)
477+
}

src/Makevars.win

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
CXX_STD = CXX14
1+
CXX_STD = CXX20
22

33
PKG_CPPFLAGS = \
44
$(shell "$(R_HOME)\bin\Rscript.exe" -e "cat(RcppParallel::CxxFlags())")

src/RcppExports.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
1212
#endif
1313

1414
// run_bgmCompare_parallel
15-
Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector<arma::imat>& counts_per_category, const std::vector<arma::imat>& blume_capel_stats, const std::vector<arma::mat>& pairwise_stats, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed, const std::string& update_method, int hmc_num_leapfrogs);
16-
RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP pairwise_statsSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP update_methodSEXP, SEXP hmc_num_leapfrogsSEXP) {
15+
Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector<arma::imat>& counts_per_category, const std::vector<arma::imat>& blume_capel_stats, const std::vector<arma::mat>& pairwise_stats, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed, const std::string& update_method, int hmc_num_leapfrogs, int progress_type);
16+
RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP pairwise_statsSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP update_methodSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP progress_typeSEXP) {
1717
BEGIN_RCPP
1818
Rcpp::RObject rcpp_result_gen;
1919
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -52,13 +52,14 @@ BEGIN_RCPP
5252
Rcpp::traits::input_parameter< int >::type seed(seedSEXP);
5353
Rcpp::traits::input_parameter< const std::string& >::type update_method(update_methodSEXP);
5454
Rcpp::traits::input_parameter< int >::type hmc_num_leapfrogs(hmc_num_leapfrogsSEXP);
55-
rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs));
55+
Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP);
56+
rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type));
5657
return rcpp_result_gen;
5758
END_RCPP
5859
}
5960
// run_bgm_parallel
60-
Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int burnin, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, uint64_t seed);
61-
RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP) {
61+
Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int burnin, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, uint64_t seed, int progress_type);
62+
RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) {
6263
BEGIN_RCPP
6364
Rcpp::RObject rcpp_result_gen;
6465
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -93,7 +94,8 @@ BEGIN_RCPP
9394
Rcpp::traits::input_parameter< int >::type num_chains(num_chainsSEXP);
9495
Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP);
9596
Rcpp::traits::input_parameter< uint64_t >::type seed(seedSEXP);
96-
rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed));
97+
Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP);
98+
rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type));
9799
return rcpp_result_gen;
98100
END_RCPP
99101
}
@@ -179,8 +181,8 @@ END_RCPP
179181
}
180182

181183
static const R_CallMethodDef CallEntries[] = {
182-
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 35},
183-
{"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 31},
184+
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 36},
185+
{"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 32},
184186
{"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0},
185187
{"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1},
186188
{"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1},

0 commit comments

Comments
 (0)