Skip to content

Commit acf3468

Browse files
Implement adaptive mala and fisher mala for thresholds and interactions.
1 parent 6c3d937 commit acf3468

File tree

6 files changed

+706
-310
lines changed

6 files changed

+706
-310
lines changed

R/RcppExports.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interact
99
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter)
1010
}
1111

12-
run_gibbs_sampler_for_bgm <- function(observations, num_categories, interaction_scale, edge_prior, theta, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, Index, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, save_main = FALSE, save_pairwise = FALSE, save_indicator = FALSE, display_progress = FALSE, edge_selection = TRUE, use_mala = FALSE) {
13-
.Call(`_bgms_run_gibbs_sampler_for_bgm`, observations, num_categories, interaction_scale, edge_prior, theta, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, Index, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, save_main, save_pairwise, save_indicator, display_progress, edge_selection, use_mala)
12+
run_gibbs_sampler_for_bgm <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, save_main = FALSE, save_pairwise = FALSE, save_indicator = FALSE, display_progress = FALSE, edge_selection = TRUE, update_method = "adaptive-metropolis") {
13+
.Call(`_bgms_run_gibbs_sampler_for_bgm`, observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, save_main, save_pairwise, save_indicator, display_progress, edge_selection, update_method)
1414
}
1515

1616
compare_anova_gibbs_sampler <- function(observations, main_effect_indices, pairwise_effect_indices, projection, num_categories, num_groups, group_indices, interaction_scale, pairwise_difference_scale, main_difference_scale, pairwise_difference_prior, main_difference_prior, inclusion_probability_difference, pairwise_beta_bernoulli_alpha, pairwise_beta_bernoulli_beta, main_beta_bernoulli_alpha, main_beta_bernoulli_beta, Index, iter, burnin, num_obs_categories, sufficient_blume_capel, prior_threshold_alpha, prior_threshold_beta, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, independent_thresholds, save_main = FALSE, save_pairwise = FALSE, save_indicator = FALSE, display_progress = FALSE, difference_selection = TRUE) {

R/bgm.R

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@
152152
#' @param display_progress Should the function show a progress bar
153153
#' (\code{display_progress = TRUE})? Or not (\code{display_progress = FALSE})?
154154
#' The default is \code{TRUE}.
155+
#' @param update_method Character. Specifies how the MCMC sampler updates the threshold
156+
#' and interaction parameters:
157+
#' \describe{
158+
#' \item{"adaptive-metropolis"}{Uses adaptive Metropolis-Hastings for both thresholds and interactions.}
159+
#' \item{"adaptive-mala"}{Uses Fisher-preconditioned MALA for thresholds and standard MALA for interactions and indicators.}
160+
#' \item{"fisher-mala"}{Uses Fisher-preconditioned MALA for both thresholds and interactions.}
161+
#' }
162+
#' Defaults to \code{"adaptive-metropolis"}.
155163
#'
156164
#' @return If \code{save = FALSE} (the default), the result is a list of class
157165
#' ``bgms'' containing the following matrices with model-averaged quantities:
@@ -317,7 +325,8 @@ bgm = function(x,
317325
save_pairwise = FALSE,
318326
save_indicator = FALSE,
319327
display_progress = TRUE,
320-
mala = FALSE) {
328+
update_method = c("adaptive-metropolis", "adaptive-mala", "fisher-mala")
329+
) {
321330

322331
# Deprecation warning for save parameter
323332
if(hasArg(save)) {
@@ -333,6 +342,11 @@ bgm = function(x,
333342
save_indicator = check_logical(save_indicator, "save_indicator")
334343

335344

345+
# Check update method
346+
update_method_input = update_method
347+
update_method = match.arg(update_method)
348+
349+
336350
#Check data input ------------------------------------------------------------
337351
if(!inherits(x, what = "matrix") && !inherits(x, what = "data.frame"))
338352
stop("The input x needs to be a matrix or dataframe.")
@@ -368,7 +382,7 @@ bgm = function(x,
368382
reference_category = model$reference_category
369383
edge_selection = model$edge_selection
370384
edge_prior = model$edge_prior
371-
theta = model$theta
385+
inclusion_probability = model$inclusion_probability
372386

373387
#Check Gibbs input -----------------------------------------------------------
374388
if(abs(iter - round(iter)) > .Machine$double.eps)
@@ -446,27 +460,28 @@ bgm = function(x,
446460
}
447461
}
448462

449-
# Index vector used to sample interactions in a random order -----------------
450-
Index = matrix(0,
463+
# Index matrix used in the c++ functions ------------------------------------
464+
interaction_index_matrix = matrix(0,
451465
nrow = num_variables * (num_variables - 1) / 2,
452466
ncol = 3)
453467
cntr = 0
454468
for(variable1 in 1:(num_variables - 1)) {
455469
for(variable2 in (variable1 + 1):num_variables) {
456470
cntr = cntr + 1
457-
Index[cntr, 1] = cntr
458-
Index[cntr, 2] = variable1 - 1
459-
Index[cntr, 3] = variable2 - 1
471+
interaction_index_matrix[cntr, 1] = cntr
472+
interaction_index_matrix[cntr, 2] = variable1 - 1
473+
interaction_index_matrix[cntr, 3] = variable2 - 1
460474
}
461475
}
462476

463477
# Call the Rcpp function
464478
out = run_gibbs_sampler_for_bgm (
465479
observations = x, num_categories = num_categories,
466480
interaction_scale = interaction_scale, edge_prior = edge_prior,
467-
theta = theta, beta_bernoulli_alpha = beta_bernoulli_alpha,
481+
inclusion_probability = inclusion_probability, beta_bernoulli_alpha = beta_bernoulli_alpha,
468482
beta_bernoulli_beta = beta_bernoulli_beta,
469-
dirichlet_alpha = dirichlet_alpha, lambda = lambda, Index = Index,
483+
dirichlet_alpha = dirichlet_alpha, lambda = lambda,
484+
interaction_index_matrix = interaction_index_matrix,
470485
iter = iter, burnin = burnin, num_obs_categories = num_obs_categories,
471486
sufficient_blume_capel = sufficient_blume_capel,
472487
threshold_alpha = threshold_alpha, threshold_beta = threshold_beta,
@@ -475,7 +490,7 @@ bgm = function(x,
475490
reference_category = reference_category, save_main = save_main,
476491
save_pairwise = save_pairwise, save_indicator = save_indicator,
477492
display_progress = display_progress, edge_selection = edge_selection,
478-
use_mala = mala
493+
update_method = update_method
479494
)
480495

481496
# Main output handler in the wrapper function
@@ -488,7 +503,7 @@ bgm = function(x,
488503
burnin = burnin, interaction_scale = interaction_scale,
489504
threshold_alpha = threshold_alpha, threshold_beta = threshold_beta,
490505
na_action = na_action, na_impute = na_impute,
491-
edge_selection = edge_selection, edge_prior = edge_prior, theta = theta,
506+
edge_selection = edge_selection, edge_prior = edge_prior, inclusion_probability = inclusion_probability,
492507
beta_bernoulli_alpha = beta_bernoulli_alpha,
493508
beta_bernoulli_beta = beta_bernoulli_beta,
494509
dirichlet_alpha = dirichlet_alpha, lambda = lambda,

R/function_input_utils.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ check_model = function(x,
226226
reference_category = reference_category,
227227
edge_selection = edge_selection,
228228
edge_prior = edge_prior,
229-
theta = theta))
229+
inclusion_probability = theta))
230230
}
231231

232232
check_compare_model = function(x,

R/output_utils.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
prepare_output_bgm = function (
44
out, x, num_categories, iter, data_columnnames, is_ordinal_variable,
55
save_options, burnin, interaction_scale, threshold_alpha, threshold_beta,
6-
na_action, na_impute, edge_selection, edge_prior, theta,
6+
na_action, na_impute, edge_selection, edge_prior, inclusion_probability,
77
beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda,
88
variable_type) {
99

@@ -15,7 +15,7 @@ prepare_output_bgm = function (
1515
burnin = burnin, interaction_scale = interaction_scale,
1616
threshold_alpha = threshold_alpha, threshold_beta = threshold_beta,
1717
edge_selection = edge_selection, edge_prior = edge_prior,
18-
inclusion_probability = theta, beta_bernoulli_alpha = beta_bernoulli_alpha,
18+
inclusion_probability = inclusion_probability, beta_bernoulli_alpha = beta_bernoulli_alpha,
1919
beta_bernoulli_beta = beta_bernoulli_beta,
2020
dirichlet_alpha = dirichlet_alpha, lambda = lambda, na_action = na_action,
2121
save = save, version = packageVersion("bgms")

src/RcppExports.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,21 @@ BEGIN_RCPP
4646
END_RCPP
4747
}
4848
// run_gibbs_sampler_for_bgm
49-
List run_gibbs_sampler_for_bgm(arma::imat& observations, const arma::ivec& num_categories, const double interaction_scale, const String& edge_prior, arma::mat& theta, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double dirichlet_alpha, const double lambda, const arma::imat& Index, const int iter, const int burnin, arma::imat& num_obs_categories, arma::imat& sufficient_blume_capel, const double threshold_alpha, const double threshold_beta, const bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& reference_category, const bool save_main, const bool save_pairwise, const bool save_indicator, const bool display_progress, bool edge_selection, bool use_mala);
50-
RcppExport SEXP _bgms_run_gibbs_sampler_for_bgm(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP interaction_scaleSEXP, SEXP edge_priorSEXP, SEXP thetaSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP IndexSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP threshold_alphaSEXP, SEXP threshold_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP reference_categorySEXP, SEXP save_mainSEXP, SEXP save_pairwiseSEXP, SEXP save_indicatorSEXP, SEXP display_progressSEXP, SEXP edge_selectionSEXP, SEXP use_malaSEXP) {
49+
List run_gibbs_sampler_for_bgm(arma::imat& observations, const arma::ivec& num_categories, const double interaction_scale, const String& edge_prior, arma::mat& inclusion_probability, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double dirichlet_alpha, const double lambda, const arma::imat& interaction_index_matrix, const int iter, const int burnin, arma::imat& num_obs_categories, arma::imat& sufficient_blume_capel, const double threshold_alpha, const double threshold_beta, const bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& reference_category, const bool save_main, const bool save_pairwise, const bool save_indicator, const bool display_progress, bool edge_selection, const std::string& update_method);
50+
RcppExport SEXP _bgms_run_gibbs_sampler_for_bgm(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP interaction_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP threshold_alphaSEXP, SEXP threshold_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP reference_categorySEXP, SEXP save_mainSEXP, SEXP save_pairwiseSEXP, SEXP save_indicatorSEXP, SEXP display_progressSEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP) {
5151
BEGIN_RCPP
5252
Rcpp::RObject rcpp_result_gen;
5353
Rcpp::RNGScope rcpp_rngScope_gen;
5454
Rcpp::traits::input_parameter< arma::imat& >::type observations(observationsSEXP);
5555
Rcpp::traits::input_parameter< const arma::ivec& >::type num_categories(num_categoriesSEXP);
5656
Rcpp::traits::input_parameter< const double >::type interaction_scale(interaction_scaleSEXP);
5757
Rcpp::traits::input_parameter< const String& >::type edge_prior(edge_priorSEXP);
58-
Rcpp::traits::input_parameter< arma::mat& >::type theta(thetaSEXP);
58+
Rcpp::traits::input_parameter< arma::mat& >::type inclusion_probability(inclusion_probabilitySEXP);
5959
Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha(beta_bernoulli_alphaSEXP);
6060
Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta(beta_bernoulli_betaSEXP);
6161
Rcpp::traits::input_parameter< const double >::type dirichlet_alpha(dirichlet_alphaSEXP);
6262
Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP);
63-
Rcpp::traits::input_parameter< const arma::imat& >::type Index(IndexSEXP);
63+
Rcpp::traits::input_parameter< const arma::imat& >::type interaction_index_matrix(interaction_index_matrixSEXP);
6464
Rcpp::traits::input_parameter< const int >::type iter(iterSEXP);
6565
Rcpp::traits::input_parameter< const int >::type burnin(burninSEXP);
6666
Rcpp::traits::input_parameter< arma::imat& >::type num_obs_categories(num_obs_categoriesSEXP);
@@ -76,8 +76,8 @@ BEGIN_RCPP
7676
Rcpp::traits::input_parameter< const bool >::type save_indicator(save_indicatorSEXP);
7777
Rcpp::traits::input_parameter< const bool >::type display_progress(display_progressSEXP);
7878
Rcpp::traits::input_parameter< bool >::type edge_selection(edge_selectionSEXP);
79-
Rcpp::traits::input_parameter< bool >::type use_mala(use_malaSEXP);
80-
rcpp_result_gen = Rcpp::wrap(run_gibbs_sampler_for_bgm(observations, num_categories, interaction_scale, edge_prior, theta, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, Index, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, save_main, save_pairwise, save_indicator, display_progress, edge_selection, use_mala));
79+
Rcpp::traits::input_parameter< const std::string& >::type update_method(update_methodSEXP);
80+
rcpp_result_gen = Rcpp::wrap(run_gibbs_sampler_for_bgm(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, save_main, save_pairwise, save_indicator, display_progress, edge_selection, update_method));
8181
return rcpp_result_gen;
8282
END_RCPP
8383
}

0 commit comments

Comments
 (0)