Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ S3method(print,summary.bgms)
S3method(summary,bgms)
export(bgm)
export(bgmCompare)
export(bgmCompare2)
export(extract_arguments)
export(extract_category_thresholds.bgms)
export(extract_edge_indicators)
Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

run_bgmCompare_parallel <- function(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads) {
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads)
run_bgmCompare_parallel <- function(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed) {
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed)
}

run_bgm_parallel <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, edge_selection, update_method, pairwise_effect_indices, target_accept, sufficient_pairwise, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed) {
Expand Down
38 changes: 4 additions & 34 deletions R/bgmCompare2.R
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ bgmCompare2 = function(
stop("Argument 'seed' must be a non-negative integer or vector of non-negative integers.")
}
# Force to integer type
seed <- as.integer(seed)
seed = as.integer(seed)
dqrng::dqset.seed(seed)
}

Expand Down Expand Up @@ -225,44 +225,14 @@ bgmCompare2 = function(
nuts_max_depth = nuts_max_depth,
learn_mass_matrix = learn_mass_matrix,
projection = projection,
group_membership = sorted_group - 1, ######################################
group_membership = sorted_group - 1,
group_indices = group_indices,
interaction_index_matrix = Index,
inclusion_probability = model$inclusion_probability_difference,
num_chains = chains, nThreads = cores
num_chains = chains, nThreads = cores,
seed = seed
)

# out = run_gibbs_sampler_for_bgmCompare(
# chain_id = 1,
# observations = observations,
# num_groups = num_groups,
# num_obs_categories = num_obs_categories,
# sufficient_blume_capel = sufficient_blume_capel,
# sufficient_pairwise = sufficient_pairwise,
# num_categories = num_categories[, 1],
# main_alpha = threshold_alpha,
# main_beta = threshold_beta,
# pairwise_scale = interaction_scale,
# difference_scale = difference_scale,
# difference_selection_alpha = beta_bernoulli_alpha,
# difference_selection_beta = beta_bernoulli_beta,
# difference_prior = model$difference_prior, iter = iter, burnin = burnin,
# na_impute = na_impute, missing_data_indices = missing_index,
# is_ordinal_variable = ordinal_variable,
# baseline_category = reference_category,
# difference_selection = difference_selection,
# main_effect_indices = main_effect_indices,
# pairwise_effect_indices = pairwise_effect_indices,
# target_accept = target_accept,
# nuts_max_depth = nuts_max_depth,
# learn_mass_matrix = learn_mass_matrix,
# projection = projection,
# group_membership = sorted_group - 1, ######################################
# group_indices = group_indices,
# interaction_index_matrix = Index,
# inclusion_probability = model$inclusion_probability_difference)


# Main output handler in the wrapper function
# output = prepare_output_bgmCompare2(
# out = out, ...
Expand Down
11 changes: 7 additions & 4 deletions R/function_input_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -737,8 +737,10 @@ check_compare2_model = function(
if(difference_inclusion_probability >= 1)
stop("The inclusion probability for differences cannot equal or exceed the value one.")

inclusion_probability_difference = inclusion_probability_difference +
difference_probability
inclusion_probability_difference = matrix(difference_probability,
nrow = ncol(x),
ncol = ncol(x))

} else {
if(!inherits(difference_probability, what = "matrix") &&
!inherits(difference_probability, what = "data.frame"))
Expand All @@ -765,8 +767,9 @@ check_compare2_model = function(
stop("One or more inclusion probabilities for differences are one or larger.")
}
} else {
inclusion_probability_difference = inclusion_probability_difference + 0.5

inclusion_probability_difference = matrix(0.5,
nrow = ncol(x),
ncol = ncol(x))
if(beta_bernoulli_alpha <= 0 || beta_bernoulli_beta <= 0)
stop("The scale parameters of the beta distribution for the differences need to be positive.")
if(!is.finite(beta_bernoulli_alpha) || !is.finite(beta_bernoulli_beta))
Expand Down
3 changes: 2 additions & 1 deletion man/bgm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 0 additions & 38 deletions man/summarySBM.Rd

This file was deleted.

4 changes: 2 additions & 2 deletions src/Makevars
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
CXX_STD = CXX14
CXX_STD = CXX20

## Pull in the include-paths for RcppParallel
PKG_CPPFLAGS = \
Expand All @@ -11,4 +11,4 @@ PKG_CPPFLAGS = \
PKG_LIBS = \
$(shell "${R_HOME}/bin/Rscript" -e "cat(RcppParallel::LdFlags())") \
-Wl,-rpath,$(shell "${R_HOME}/bin/Rscript" -e "cat(system.file('lib',package='RcppParallel'))") \
$(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)
$(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)
41 changes: 21 additions & 20 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,44 +12,45 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// run_bgmCompare_parallel
Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, const int num_groups, const std::vector<arma::imat>& num_obs_categories, const std::vector<arma::imat>& sufficient_blume_capel, const std::vector<arma::mat>& sufficient_pairwise, const arma::ivec& num_categories, const double main_alpha, const double main_beta, const double pairwise_scale, const double difference_scale, const double difference_selection_alpha, const double difference_selection_beta, const std::string& difference_prior, const int iter, const int burnin, const bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, const bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, const double target_accept, const int nuts_max_depth, const bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, const int num_chains, const int nThreads);
RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP sufficient_pairwiseSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP) {
Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector<arma::imat>& num_obs_categories, const std::vector<arma::imat>& sufficient_blume_capel, const std::vector<arma::mat>& sufficient_pairwise, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int burnin, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed);
RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP sufficient_pairwiseSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const arma::imat& >::type observations(observationsSEXP);
Rcpp::traits::input_parameter< const int >::type num_groups(num_groupsSEXP);
Rcpp::traits::input_parameter< int >::type num_groups(num_groupsSEXP);
Rcpp::traits::input_parameter< const std::vector<arma::imat>& >::type num_obs_categories(num_obs_categoriesSEXP);
Rcpp::traits::input_parameter< const std::vector<arma::imat>& >::type sufficient_blume_capel(sufficient_blume_capelSEXP);
Rcpp::traits::input_parameter< const std::vector<arma::mat>& >::type sufficient_pairwise(sufficient_pairwiseSEXP);
Rcpp::traits::input_parameter< const arma::ivec& >::type num_categories(num_categoriesSEXP);
Rcpp::traits::input_parameter< const double >::type main_alpha(main_alphaSEXP);
Rcpp::traits::input_parameter< const double >::type main_beta(main_betaSEXP);
Rcpp::traits::input_parameter< const double >::type pairwise_scale(pairwise_scaleSEXP);
Rcpp::traits::input_parameter< const double >::type difference_scale(difference_scaleSEXP);
Rcpp::traits::input_parameter< const double >::type difference_selection_alpha(difference_selection_alphaSEXP);
Rcpp::traits::input_parameter< const double >::type difference_selection_beta(difference_selection_betaSEXP);
Rcpp::traits::input_parameter< double >::type main_alpha(main_alphaSEXP);
Rcpp::traits::input_parameter< double >::type main_beta(main_betaSEXP);
Rcpp::traits::input_parameter< double >::type pairwise_scale(pairwise_scaleSEXP);
Rcpp::traits::input_parameter< double >::type difference_scale(difference_scaleSEXP);
Rcpp::traits::input_parameter< double >::type difference_selection_alpha(difference_selection_alphaSEXP);
Rcpp::traits::input_parameter< double >::type difference_selection_beta(difference_selection_betaSEXP);
Rcpp::traits::input_parameter< const std::string& >::type difference_prior(difference_priorSEXP);
Rcpp::traits::input_parameter< const int >::type iter(iterSEXP);
Rcpp::traits::input_parameter< const int >::type burnin(burninSEXP);
Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP);
Rcpp::traits::input_parameter< int >::type iter(iterSEXP);
Rcpp::traits::input_parameter< int >::type burnin(burninSEXP);
Rcpp::traits::input_parameter< bool >::type na_impute(na_imputeSEXP);
Rcpp::traits::input_parameter< const arma::imat& >::type missing_data_indices(missing_data_indicesSEXP);
Rcpp::traits::input_parameter< const arma::uvec& >::type is_ordinal_variable(is_ordinal_variableSEXP);
Rcpp::traits::input_parameter< const arma::ivec& >::type baseline_category(baseline_categorySEXP);
Rcpp::traits::input_parameter< const bool >::type difference_selection(difference_selectionSEXP);
Rcpp::traits::input_parameter< bool >::type difference_selection(difference_selectionSEXP);
Rcpp::traits::input_parameter< const arma::imat& >::type main_effect_indices(main_effect_indicesSEXP);
Rcpp::traits::input_parameter< const arma::imat& >::type pairwise_effect_indices(pairwise_effect_indicesSEXP);
Rcpp::traits::input_parameter< const double >::type target_accept(target_acceptSEXP);
Rcpp::traits::input_parameter< const int >::type nuts_max_depth(nuts_max_depthSEXP);
Rcpp::traits::input_parameter< const bool >::type learn_mass_matrix(learn_mass_matrixSEXP);
Rcpp::traits::input_parameter< double >::type target_accept(target_acceptSEXP);
Rcpp::traits::input_parameter< int >::type nuts_max_depth(nuts_max_depthSEXP);
Rcpp::traits::input_parameter< bool >::type learn_mass_matrix(learn_mass_matrixSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type projection(projectionSEXP);
Rcpp::traits::input_parameter< const arma::ivec& >::type group_membership(group_membershipSEXP);
Rcpp::traits::input_parameter< const arma::imat& >::type group_indices(group_indicesSEXP);
Rcpp::traits::input_parameter< const arma::imat& >::type interaction_index_matrix(interaction_index_matrixSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type inclusion_probability(inclusion_probabilitySEXP);
Rcpp::traits::input_parameter< const int >::type num_chains(num_chainsSEXP);
Rcpp::traits::input_parameter< const int >::type nThreads(nThreadsSEXP);
rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads));
Rcpp::traits::input_parameter< int >::type num_chains(num_chainsSEXP);
Rcpp::traits::input_parameter< int >::type nThreads(nThreadsSEXP);
Rcpp::traits::input_parameter< int >::type seed(seedSEXP);
rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, num_obs_categories, sufficient_blume_capel, sufficient_pairwise, num_categories, main_alpha, main_beta, pairwise_scale, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, burnin, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -188,7 +189,7 @@ END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 32},
{"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 33},
{"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 31},
{"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6},
{"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8},
Expand Down
Loading