Skip to content

Commit c925d92

Browse files
sekulovskin and Nikola authored
two sets of shape hyperparameters for the Beta hyperprior in the SBM (#68)
* two sets of shape hyperparameters for the Beta hyperprior in the SBM * Update documentation for the extra shape parameters * update documentation * address review comments --------- Co-authored-by: Nikola <[email protected]>
1 parent 79ccb11 commit c925d92

File tree

11 files changed

+188
-60
lines changed

11 files changed

+188
-60
lines changed

R/RcppExports.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ run_bgmCompare_parallel <- function(observations, num_groups, counts_per_categor
55
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type)
66
}
77

8-
run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) {
9-
.Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)
8+
run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type) {
9+
.Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)
1010
}
1111

1212
get_explog_switch <- function() {

R/bgm.R

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,14 @@
207207
#'
208208
#' @param beta_bernoulli_alpha,beta_bernoulli_beta Double. Shape parameters
209209
#' for the beta distribution in the Beta–Bernoulli and the Stochastic-Block
210-
#' priors. Must be positive. Defaults: \code{beta_bernoulli_alpha = 1} and
211-
#' \code{beta_bernoulli_beta = 1}.
210+
#' priors. Must be positive. For the Stochastic-Block prior these are the shape
211+
#' parameters for the within-cluster edge inclusion probabilities.
212+
#' Defaults: \code{beta_bernoulli_alpha = 1} and \code{beta_bernoulli_beta = 1}.
213+
#'
214+
#' @param beta_bernoulli_alpha_between,beta_bernoulli_beta_between Double.
215+
#' Shape parameters for the between-cluster edge inclusion probabilities in the
216+
#' Stochastic-Block prior. Must be positive.
217+
#' Defaults: \code{beta_bernoulli_alpha_between = 1} and \code{beta_bernoulli_beta_between = 1}.
212218
#'
213219
#' @param dirichlet_alpha Double. Concentration parameter of the Dirichlet
214220
#' prior on block assignments (used with the Stochastic Block model).
@@ -359,6 +365,8 @@ bgm = function(
359365
inclusion_probability = 0.5,
360366
beta_bernoulli_alpha = 1,
361367
beta_bernoulli_beta = 1,
368+
beta_bernoulli_alpha_between = 1,
369+
beta_bernoulli_beta_between = 1,
362370
dirichlet_alpha = 1,
363371
lambda = 1,
364372
na_action = c("listwise", "impute"),
@@ -418,7 +426,7 @@ bgm = function(
418426
} else if(update_method == "hamiltonian-mc") {
419427
target_accept = 0.65
420428
} else if(update_method == "nuts") {
421-
target_accept = 0.80
429+
target_accept = 0.60
422430
}
423431
}
424432

@@ -444,9 +452,21 @@ bgm = function(
444452
inclusion_probability = inclusion_probability,
445453
beta_bernoulli_alpha = beta_bernoulli_alpha,
446454
beta_bernoulli_beta = beta_bernoulli_beta,
455+
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
456+
beta_bernoulli_beta_between = beta_bernoulli_beta_between,
447457
dirichlet_alpha = dirichlet_alpha,
448458
lambda = lambda)
449459

460+
# check hyperparameters input
461+
# If user left them NULL, pass -1 to C++ (means: ignore between prior)
462+
if (is.null(beta_bernoulli_alpha_between) && is.null(beta_bernoulli_beta_between)) {
463+
beta_bernoulli_alpha_between <- -1.0
464+
beta_bernoulli_beta_between <- -1.0
465+
} else if (is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) {
466+
stop("If you wish to specify different between and within cluster probabilites,
467+
provide both beta_bernoulli_alpha_between and beta_bernoulli_beta_between,
468+
otherwise leave both NULL.")
469+
}
450470
# ----------------------------------------------------------------------------
451471
# The vector variable_type is now coded as boolean.
452472
# Ordinal (variable_bool == TRUE) or Blume-Capel (variable_bool == FALSE)
@@ -572,6 +592,8 @@ bgm = function(
572592
inclusion_probability = inclusion_probability,
573593
beta_bernoulli_alpha = beta_bernoulli_alpha,
574594
beta_bernoulli_beta = beta_bernoulli_beta,
595+
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
596+
beta_bernoulli_beta_between = beta_bernoulli_beta_between,
575597
dirichlet_alpha = dirichlet_alpha, lambda = lambda,
576598
interaction_index_matrix = interaction_index_matrix, iter = iter,
577599
warmup = warmup, counts_per_category = counts_per_category,
@@ -603,6 +625,7 @@ bgm = function(
603625
na_action = na_action, na_impute = na_impute,
604626
edge_selection = edge_selection, edge_prior = edge_prior, inclusion_probability = inclusion_probability,
605627
beta_bernoulli_alpha = beta_bernoulli_alpha, beta_bernoulli_beta = beta_bernoulli_beta,
628+
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, beta_bernoulli_beta_between = beta_bernoulli_beta_between,
606629
dirichlet_alpha = dirichlet_alpha, lambda = lambda,
607630
variable_type = variable_type,
608631
update_method = update_method,
@@ -634,6 +657,8 @@ bgm = function(
634657
edge_selection = edge_selection, edge_prior = edge_prior, inclusion_probability = inclusion_probability,
635658
beta_bernoulli_alpha = beta_bernoulli_alpha,
636659
beta_bernoulli_beta = beta_bernoulli_beta,
660+
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
661+
beta_bernoulli_beta_between = beta_bernoulli_beta_between,
637662
dirichlet_alpha = dirichlet_alpha, lambda = lambda,
638663
variable_type = variable_type,
639664
update_method = update_method,

R/function_input_utils.R

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ check_model = function(x,
3131
inclusion_probability = 0.5,
3232
beta_bernoulli_alpha = 1,
3333
beta_bernoulli_beta = 1,
34+
beta_bernoulli_alpha_between = 1,
35+
beta_bernoulli_beta_between = 1,
3436
dirichlet_alpha = dirichlet_alpha,
3537
lambda = lambda) {
3638

@@ -204,18 +206,42 @@ check_model = function(x,
204206
is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta))
205207
stop("Values for both scale parameters of the beta distribution need to be specified.")
206208
}
209+
207210
if(edge_prior == "Stochastic-Block") {
208211
theta = matrix(0.5, nrow = ncol(x), ncol = ncol(x))
209-
if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0 || dirichlet_alpha <= 0 || lambda <= 0)
210-
stop("The scale parameters of the beta and Dirichlet distribution need to be positive.")
211-
if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta) || !is.finite(dirichlet_alpha) || !is.finite(lambda))
212-
stop("The scale parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, and the rate parameter of the Poisson distribution need to be finite.")
213-
if(is.na(beta_bernoulli_alpha) || is.na(beta_bernoulli_beta) ||
214-
is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta) ||
215-
is.null(dirichlet_alpha) || is.null(dirichlet_alpha) || is.null(lambda) || is.null(lambda))
216-
stop("Values for both scale parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, and the rate parameter of the Poisson distribution need to be specified.")
212+
213+
# Check that all beta parameters are provided
214+
if (is.null(beta_bernoulli_alpha) || is.null(beta_bernoulli_beta) ||
215+
is.null(beta_bernoulli_alpha_between) || is.null(beta_bernoulli_beta_between)) {
216+
stop("The Stochastic-Block prior requires all four beta parameters: ",
217+
"beta_bernoulli_alpha, beta_bernoulli_beta, ",
218+
"beta_bernoulli_alpha_between, and beta_bernoulli_beta_between.")
219+
}
220+
221+
# Check that the beta shape, Dirichlet concentration, and Poisson rate parameters are positive
222+
if (beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0 ||
223+
beta_bernoulli_alpha_between <= 0 || beta_bernoulli_beta_between <= 0 ||
224+
dirichlet_alpha <= 0 || lambda <= 0) {
225+
stop("The parameters of the beta and Dirichlet distributions need to be positive.")
226+
}
227+
228+
# Check that the beta shape, Dirichlet concentration, and Poisson rate parameters are finite
229+
if (!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta) ||
230+
!is.finite(beta_bernoulli_alpha_between) || !is.finite(beta_bernoulli_beta_between) ||
231+
!is.finite(dirichlet_alpha) || !is.finite(lambda)) {
232+
stop("The shape parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, ",
233+
"and the rate parameter of the Poisson distribution need to be finite.")
234+
}
235+
236+
# Check for NAs
237+
if (is.na(beta_bernoulli_alpha) || is.na(beta_bernoulli_beta) ||
238+
is.na(beta_bernoulli_alpha_between) || is.na(beta_bernoulli_beta_between) ||
239+
is.na(dirichlet_alpha) || is.na(lambda)) {
240+
stop("Values for all shape parameters of the beta distribution, the concentration parameter of the Dirichlet distribution, ",
241+
"and the rate parameter of the Poisson distribution cannot be NA.")
242+
}
217243
}
218-
} else {
244+
}else {
219245
theta = matrix(0.5, nrow = 1, ncol = 1)
220246
edge_prior = "Not Applicable"
221247
}

R/output_utils.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ prepare_output_bgm = function(
22
out, x, num_categories, iter, data_columnnames, is_ordinal_variable,
33
warmup, pairwise_scale, main_alpha, main_beta,
44
na_action, na_impute, edge_selection, edge_prior, inclusion_probability,
5-
beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda,
5+
beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between,
6+
beta_bernoulli_beta_between,dirichlet_alpha, lambda,
67
variable_type, update_method, target_accept, hmc_num_leapfrogs,
78
nuts_max_depth, learn_mass_matrix, num_chains
89
) {
@@ -22,6 +23,8 @@ prepare_output_bgm = function(
2223
inclusion_probability = inclusion_probability,
2324
beta_bernoulli_alpha = beta_bernoulli_alpha,
2425
beta_bernoulli_beta = beta_bernoulli_beta,
26+
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
27+
beta_bernoulli_beta_between = beta_bernoulli_beta_between,
2528
dirichlet_alpha = dirichlet_alpha,
2629
lambda = lambda,
2730
na_action = na_action,

man/bgm.Rd

Lines changed: 10 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/RcppExports.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ BEGIN_RCPP
5858
END_RCPP
5959
}
6060
// run_bgm_parallel
61-
Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int warmup, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, int seed, int progress_type);
62-
RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP warmupSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) {
61+
Rcpp::List run_bgm_parallel(const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, const std::string& edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, double beta_bernoulli_alpha_between, double beta_bernoulli_beta_between, double dirichlet_alpha, double lambda, const arma::imat& interaction_index_matrix, int iter, int warmup, const arma::imat& counts_per_category, const arma::imat& blume_capel_stats, double main_alpha, double main_beta, bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, const std::string& update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, int hmc_num_leapfrogs, int nuts_max_depth, bool learn_mass_matrix, int num_chains, int nThreads, int seed, int progress_type);
62+
RcppExport SEXP _bgms_run_bgm_parallel(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP pairwise_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP warmupSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP pairwise_statsSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP progress_typeSEXP) {
6363
BEGIN_RCPP
6464
Rcpp::RObject rcpp_result_gen;
6565
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -70,6 +70,8 @@ BEGIN_RCPP
7070
Rcpp::traits::input_parameter< const arma::mat& >::type inclusion_probability(inclusion_probabilitySEXP);
7171
Rcpp::traits::input_parameter< double >::type beta_bernoulli_alpha(beta_bernoulli_alphaSEXP);
7272
Rcpp::traits::input_parameter< double >::type beta_bernoulli_beta(beta_bernoulli_betaSEXP);
73+
Rcpp::traits::input_parameter< double >::type beta_bernoulli_alpha_between(beta_bernoulli_alpha_betweenSEXP);
74+
Rcpp::traits::input_parameter< double >::type beta_bernoulli_beta_between(beta_bernoulli_beta_betweenSEXP);
7375
Rcpp::traits::input_parameter< double >::type dirichlet_alpha(dirichlet_alphaSEXP);
7476
Rcpp::traits::input_parameter< double >::type lambda(lambdaSEXP);
7577
Rcpp::traits::input_parameter< const arma::imat& >::type interaction_index_matrix(interaction_index_matrixSEXP);
@@ -95,7 +97,7 @@ BEGIN_RCPP
9597
Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP);
9698
Rcpp::traits::input_parameter< int >::type seed(seedSEXP);
9799
Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP);
98-
rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type));
100+
rcpp_result_gen = Rcpp::wrap(run_bgm_parallel(observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type));
99101
return rcpp_result_gen;
100102
END_RCPP
101103
}
@@ -182,7 +184,7 @@ END_RCPP
182184

183185
static const R_CallMethodDef CallEntries[] = {
184186
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 36},
185-
{"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 32},
187+
{"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 34},
186188
{"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0},
187189
{"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1},
188190
{"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1},

0 commit comments

Comments
 (0)