Skip to content

Commit cee8abd

Browse files
Created functions for computing marginal inclusion probabilities, stashed in archived code/gibbs_functions_moms.cpp
1 parent 2deb255 commit cee8abd

File tree

3 files changed

+17
-229
lines changed

3 files changed

+17
-229
lines changed

R/RcppExports.R

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@ sample_bcomrf_gibbs <- function(no_states, no_variables, no_categories, interact
99
.Call(`_bgms_sample_bcomrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, variable_type, reference_category, iter)
1010
}
1111

12-
optimize_log_pseudoposterior_interaction <- function(initial_value, pairwise_effects, main_effects, observations, num_categories, num_persons, variable1, variable2, proposed_state, current_state, residual_matrix, is_ordinal_variable, reference_category, interaction_scale) {
13-
.Call(`_bgms_optimize_log_pseudoposterior_interaction`, initial_value, pairwise_effects, main_effects, observations, num_categories, num_persons, variable1, variable2, proposed_state, current_state, residual_matrix, is_ordinal_variable, reference_category, interaction_scale)
14-
}
15-
1612
run_gibbs_sampler_for_bgm <- function(observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, save_main = FALSE, save_pairwise = FALSE, save_indicator = FALSE, display_progress = FALSE, edge_selection = TRUE, update_method = "adaptive-metropolis") {
1713
.Call(`_bgms_run_gibbs_sampler_for_bgm`, observations, num_categories, interaction_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, burnin, num_obs_categories, sufficient_blume_capel, threshold_alpha, threshold_beta, na_impute, missing_index, is_ordinal_variable, reference_category, save_main, save_pairwise, save_indicator, display_progress, edge_selection, update_method)
1814
}

src/RcppExports.cpp

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,30 +45,6 @@ BEGIN_RCPP
4545
return rcpp_result_gen;
4646
END_RCPP
4747
}
48-
// optimize_log_pseudoposterior_interaction
49-
double optimize_log_pseudoposterior_interaction(const double initial_value, arma::mat& pairwise_effects, const arma::mat& main_effects, const arma::imat& observations, const arma::ivec& num_categories, const int num_persons, const int variable1, const int variable2, const double proposed_state, const double current_state, const arma::mat& residual_matrix, const arma::uvec& is_ordinal_variable, const arma::ivec& reference_category, const double interaction_scale);
50-
RcppExport SEXP _bgms_optimize_log_pseudoposterior_interaction(SEXP initial_valueSEXP, SEXP pairwise_effectsSEXP, SEXP main_effectsSEXP, SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP num_personsSEXP, SEXP variable1SEXP, SEXP variable2SEXP, SEXP proposed_stateSEXP, SEXP current_stateSEXP, SEXP residual_matrixSEXP, SEXP is_ordinal_variableSEXP, SEXP reference_categorySEXP, SEXP interaction_scaleSEXP) {
51-
BEGIN_RCPP
52-
Rcpp::RObject rcpp_result_gen;
53-
Rcpp::RNGScope rcpp_rngScope_gen;
54-
Rcpp::traits::input_parameter< const double >::type initial_value(initial_valueSEXP);
55-
Rcpp::traits::input_parameter< arma::mat& >::type pairwise_effects(pairwise_effectsSEXP);
56-
Rcpp::traits::input_parameter< const arma::mat& >::type main_effects(main_effectsSEXP);
57-
Rcpp::traits::input_parameter< const arma::imat& >::type observations(observationsSEXP);
58-
Rcpp::traits::input_parameter< const arma::ivec& >::type num_categories(num_categoriesSEXP);
59-
Rcpp::traits::input_parameter< const int >::type num_persons(num_personsSEXP);
60-
Rcpp::traits::input_parameter< const int >::type variable1(variable1SEXP);
61-
Rcpp::traits::input_parameter< const int >::type variable2(variable2SEXP);
62-
Rcpp::traits::input_parameter< const double >::type proposed_state(proposed_stateSEXP);
63-
Rcpp::traits::input_parameter< const double >::type current_state(current_stateSEXP);
64-
Rcpp::traits::input_parameter< const arma::mat& >::type residual_matrix(residual_matrixSEXP);
65-
Rcpp::traits::input_parameter< const arma::uvec& >::type is_ordinal_variable(is_ordinal_variableSEXP);
66-
Rcpp::traits::input_parameter< const arma::ivec& >::type reference_category(reference_categorySEXP);
67-
Rcpp::traits::input_parameter< const double >::type interaction_scale(interaction_scaleSEXP);
68-
rcpp_result_gen = Rcpp::wrap(optimize_log_pseudoposterior_interaction(initial_value, pairwise_effects, main_effects, observations, num_categories, num_persons, variable1, variable2, proposed_state, current_state, residual_matrix, is_ordinal_variable, reference_category, interaction_scale));
69-
return rcpp_result_gen;
70-
END_RCPP
71-
}
7248
// run_gibbs_sampler_for_bgm
7349
List run_gibbs_sampler_for_bgm(arma::imat& observations, const arma::ivec& num_categories, const double interaction_scale, const String& edge_prior, arma::mat& inclusion_probability, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double dirichlet_alpha, const double lambda, const arma::imat& interaction_index_matrix, const int iter, const int burnin, arma::imat& num_obs_categories, arma::imat& sufficient_blume_capel, const double threshold_alpha, const double threshold_beta, const bool na_impute, const arma::imat& missing_index, const arma::uvec& is_ordinal_variable, const arma::ivec& reference_category, const bool save_main, const bool save_pairwise, const bool save_indicator, const bool display_progress, bool edge_selection, const std::string& update_method);
7450
RcppExport SEXP _bgms_run_gibbs_sampler_for_bgm(SEXP observationsSEXP, SEXP num_categoriesSEXP, SEXP interaction_scaleSEXP, SEXP edge_priorSEXP, SEXP inclusion_probabilitySEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP interaction_index_matrixSEXP, SEXP iterSEXP, SEXP burninSEXP, SEXP num_obs_categoriesSEXP, SEXP sufficient_blume_capelSEXP, SEXP threshold_alphaSEXP, SEXP threshold_betaSEXP, SEXP na_imputeSEXP, SEXP missing_indexSEXP, SEXP is_ordinal_variableSEXP, SEXP reference_categorySEXP, SEXP save_mainSEXP, SEXP save_pairwiseSEXP, SEXP save_indicatorSEXP, SEXP display_progressSEXP, SEXP edge_selectionSEXP, SEXP update_methodSEXP) {
@@ -167,7 +143,6 @@ END_RCPP
167143
static const R_CallMethodDef CallEntries[] = {
168144
{"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 6},
169145
{"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 8},
170-
{"_bgms_optimize_log_pseudoposterior_interaction", (DL_FUNC) &_bgms_optimize_log_pseudoposterior_interaction, 14},
171146
{"_bgms_run_gibbs_sampler_for_bgm", (DL_FUNC) &_bgms_run_gibbs_sampler_for_bgm, 26},
172147
{"_bgms_compare_anova_gibbs_sampler", (DL_FUNC) &_bgms_compare_anova_gibbs_sampler, 34},
173148
{"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4},

src/gibbs_functions.cpp

Lines changed: 17 additions & 200 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
using namespace Rcpp;
1010

1111

12-
1312
/**
1413
* Adapts the log step size using dual averaging during MCMC burn-in.
1514
*
@@ -1398,110 +1397,6 @@ double gradient_log_pseudoposterior_interaction_single (
13981397
}
13991398

14001399

1401-
double hessian_log_pseudoposterior_interaction_single (
1402-
int var1,
1403-
int var2,
1404-
const arma::mat& pairwise_effects,
1405-
const arma::mat& main_effects,
1406-
const arma::imat& observations,
1407-
const arma::ivec& num_categories,
1408-
const arma::uvec& is_ordinal_variable,
1409-
const arma::ivec& reference_category,
1410-
const double interaction_scale
1411-
) {
1412-
const int num_persons = observations.n_rows;
1413-
1414-
// Extract observed score vectors for each variable
1415-
arma::vec x_var1 = arma::conv_to<arma::vec>::from (observations.col (var1));
1416-
arma::vec x_var2 = arma::conv_to<arma::vec>::from (observations.col (var2));
1417-
1418-
// First-order gradient from data
1419-
double hessian = 0.0;
1420-
1421-
// --- Contribution from var1
1422-
int num_categories_var1 = num_categories (var1);
1423-
arma::vec rest_scores_var1 = observations * pairwise_effects.col (var1); // β_{var1,var1} = 0
1424-
arma::vec numerator_var1_E (num_persons, arma::fill::zeros);
1425-
arma::vec denominator_var1 (num_persons, arma::fill::zeros);
1426-
arma::vec numerator_var1_E2 (num_persons, arma::fill::zeros);
1427-
arma::vec bounds_var1 = arma::max (rest_scores_var1, arma::zeros<arma::vec> (num_persons)) * num_categories_var1;
1428-
1429-
if (is_ordinal_variable (var1)) {
1430-
denominator_var1 += arma::exp ( -bounds_var1 );
1431-
for (int category = 0; category < num_categories_var1; category++) {
1432-
arma::vec exponent = main_effects (var1, category) + (category + 1) * rest_scores_var1 - bounds_var1;
1433-
arma::vec weight = arma::exp (exponent);
1434-
denominator_var1 += weight;
1435-
numerator_var1_E += (category + 1) * x_var2 % weight;
1436-
numerator_var1_E2 += (category + 1) * (category + 1) * x_var2 % x_var2 % weight;
1437-
}
1438-
} else {
1439-
const int ref_cat = reference_category (var1);
1440-
for (int category = 0; category <= num_categories_var1; category++) {
1441-
int centered = category - ref_cat;
1442-
double lin_term = main_effects (var1, 0) * category;
1443-
double quad_term = main_effects (var1, 1) * centered * centered;
1444-
arma::vec exponent = lin_term + quad_term + category * rest_scores_var1 - bounds_var1;
1445-
arma::vec weight = arma::exp (exponent);
1446-
denominator_var1 += weight;
1447-
numerator_var1_E += category * x_var2 % weight;
1448-
numerator_var1_E2 += category * category * x_var2 % x_var2 % weight;
1449-
}
1450-
}
1451-
//- E((XiXj)^2)
1452-
hessian -= arma::accu (numerator_var1_E2 / denominator_var1);
1453-
1454-
//+E(XiXj)^2
1455-
arma::vec expectation = numerator_var1_E / denominator_var1;
1456-
hessian += arma::accu(arma::square(expectation));
1457-
1458-
// --- Contribution from var2
1459-
int num_categories_var2 = num_categories (var2);
1460-
arma::vec rest_scores_var2 = observations * pairwise_effects.col (var2);
1461-
arma::vec numerator_var2_E (num_persons, arma::fill::zeros);
1462-
arma::vec numerator_var2_E2 (num_persons, arma::fill::zeros);
1463-
arma::vec denominator_var2 (num_persons, arma::fill::zeros);
1464-
arma::vec bounds_var2 = arma::max (rest_scores_var2, arma::zeros<arma::vec> (num_persons)) * num_categories_var2;
1465-
1466-
if (is_ordinal_variable (var2)) {
1467-
denominator_var2 += arma::exp ( -bounds_var2 );
1468-
for (int category = 0; category < num_categories_var2; category++) {
1469-
arma::vec exponent = main_effects (var2, category) + (category + 1) * rest_scores_var2 - bounds_var2;
1470-
arma::vec weight = arma::exp (exponent);
1471-
denominator_var2 += weight;
1472-
numerator_var2_E += (category + 1) * x_var1 % weight;
1473-
numerator_var2_E2 += (category + 1) * (category + 1) * x_var1 % x_var1 % weight;
1474-
}
1475-
} else {
1476-
const int ref_cat = reference_category (var2);
1477-
for (int category = 0; category <= num_categories_var2; category++) {
1478-
int centered = category - ref_cat;
1479-
double lin_term = main_effects (var2, 0) * category;
1480-
double quad_term = main_effects (var2, 1) * centered * centered;
1481-
arma::vec exponent = lin_term + quad_term + category * rest_scores_var2 - bounds_var2;
1482-
arma::vec weight = arma::exp (exponent);
1483-
denominator_var2 += weight;
1484-
numerator_var2_E += category * x_var1 % weight;
1485-
numerator_var2_E2 += category * category * x_var1 % x_var1 % weight;
1486-
}
1487-
}
1488-
1489-
//- E((XiXj)^2)
1490-
hessian -= arma::accu (numerator_var2_E2 / denominator_var2);
1491-
1492-
//+E(XiXj)^2
1493-
expectation = numerator_var2_E / denominator_var2;
1494-
hessian += arma::accu(arma::square(expectation));
1495-
1496-
1497-
// --- Cauchy prior derivative
1498-
double beta = pairwise_effects (var1, var2) * pairwise_effects (var1, var2);
1499-
double s = interaction_scale * interaction_scale;
1500-
hessian += 2.0 * (beta - s) / ((beta + s) * (beta + s));
1501-
1502-
return hessian;
1503-
}
1504-
15051400

15061401
/**
15071402
* Function: log_pseudoposterior_interactions
@@ -1542,7 +1437,7 @@ double log_pseudoposterior_interactions (
15421437
arma::mat real_observations = arma::conv_to<arma::mat>::from (observations);
15431438

15441439
// Leading term: trace(X * B * X^T)
1545-
double log_pseudo_likelihood = arma::trace (real_observations * pairwise_effects * real_observations.t ());
1440+
double log_pseudo_posterior = arma::trace (real_observations * pairwise_effects * real_observations.t ());
15461441

15471442
for (int var = 0; var < num_variables; var++) {
15481443
int num_categories_var = num_categories (var);
@@ -1572,105 +1467,25 @@ double log_pseudoposterior_interactions (
15721467
}
15731468

15741469
// Subtract log partition function and bounds adjustment
1575-
log_pseudo_likelihood -= arma::accu (arma::log (denominator));
1576-
log_pseudo_likelihood -= arma::accu (bounds);
1470+
log_pseudo_posterior -= arma::accu (arma::log (denominator));
1471+
log_pseudo_posterior -= arma::accu (bounds);
15771472
}
15781473

15791474
// Add Cauchy prior terms for included pairwise effects
15801475
for (int var1 = 0; var1 < num_variables - 1; var1++) {
15811476
for (int var2 = var1 + 1; var2 < num_variables; var2++) {
15821477
if (inclusion_indicator (var1, var2) == 1) {
1583-
log_pseudo_likelihood += R::dcauchy (pairwise_effects (var1, var2), 0.0, interaction_scale, true);
1478+
log_pseudo_posterior += R::dcauchy (pairwise_effects (var1, var2), 0.0, interaction_scale, true);
15841479
}
15851480
}
15861481
}
15871482

1588-
return log_pseudo_likelihood;
1483+
return log_pseudo_posterior;
15891484
}
15901485

15911486

1592-
/**
1593-
*
1594-
*
1595-
*
1596-
*/
1597-
//[[Rcpp::export]]
1598-
double optimize_log_pseudoposterior_interaction (
1599-
const double initial_value,
1600-
arma::mat& pairwise_effects,
1601-
const arma::mat& main_effects,
1602-
const arma::imat& inclusion_indicator,
1603-
const arma::imat& observations,
1604-
const arma::ivec& num_categories,
1605-
const int num_persons,
1606-
const int variable1,
1607-
const int variable2,
1608-
const double proposed_state,
1609-
const double current_state,
1610-
const arma::mat& residual_matrix,
1611-
const arma::uvec& is_ordinal_variable,
1612-
const arma::ivec& reference_category,
1613-
const double interaction_scale
1614-
) {
1615-
1616-
double x = initial_value;
16171487

1618-
const int max_steps = 10;
1619-
const double tolerance = 1e-6;//sqrt (std::numeric_limits<double>::epsilon ());
16201488

1621-
const double x0 = pairwise_effects(variable1, variable2);
1622-
double hessian_at_x;
1623-
// find mode
1624-
for (int t = 0; t < max_steps; t++) {
1625-
1626-
// TODO: need to assign x to pairwise_effects[variable1, variable2]
1627-
pairwise_effects(variable1, variable2) = x;
1628-
pairwise_effects(variable2, variable1) = x;
1629-
Rcpp::Rcout << "t: " << t << " x: " << x << std::endl;
1630-
double gradient_at_x = gradient_log_pseudoposterior_interaction_single (
1631-
variable1, variable2, pairwise_effects, main_effects, observations,
1632-
num_categories, is_ordinal_variable, reference_category, interaction_scale
1633-
);
1634-
1635-
Rcpp::Rcout << "hessian_at_x" << std::endl;
1636-
hessian_at_x = hessian_log_pseudoposterior_interaction_single (
1637-
variable1, variable2, pairwise_effects, main_effects, observations,
1638-
num_categories, is_ordinal_variable, reference_category, interaction_scale
1639-
);
1640-
1641-
// double x_new = x - gradient_at_x / hessian_at_x;
1642-
double x_new = x - gradient_at_x / hessian_at_x;
1643-
1644-
if (std::abs(x_new - x) < tolerance) {
1645-
x = x_new;
1646-
break;
1647-
}
1648-
x = x_new;
1649-
1650-
}
1651-
1652-
pairwise_effects(variable1, variable2) = x;
1653-
pairwise_effects(variable2, variable1) = x;
1654-
1655-
const double fx = log_pseudoposterior_interactions(
1656-
pairwise_effects,
1657-
main_effects,
1658-
observations,
1659-
num_categories,
1660-
inclusion_indicator,
1661-
is_ordinal_variable,
1662-
reference_category,
1663-
interaction_scale
1664-
);
1665-
1666-
pairwise_effects(variable1, variable2) = x0;
1667-
pairwise_effects(variable2, variable1) = x0;
1668-
1669-
// @maarten not sure if you need or want both?
1670-
const double log_integral = fx + (log(2 * M_PI) - log(-hessian_at_x)) / 2;
1671-
return x;
1672-
1673-
}
16741489

16751490

16761491

@@ -2797,6 +2612,7 @@ void update_indicator_interaction_pair_with_fisher_mala (
27972612
}
27982613

27992614

2615+
28002616
/**
28012617
* Performs a single iteration of the Gibbs sampler for graphical model parameters.
28022618
*
@@ -2894,7 +2710,8 @@ void gibbs_update_step_for_graphical_model_parameters (
28942710
arma::mat& sqrt_inv_fisher_pairwise,
28952711
const std::string& update_method,
28962712
arma::vec& cached_interaction_gradient,
2897-
bool& gradient_valid
2713+
bool& gradient_valid,
2714+
arma::vec& posterior_prob
28982715
) {
28992716
// --- Robbins-Monro weight for adaptive Metropolis updates
29002717
const double exp_neg_log_t_rm_adaptation_rate =
@@ -3088,11 +2905,11 @@ List run_gibbs_sampler_for_bgm (
30882905
const int num_main = count_num_main_effects(num_categories, is_ordinal_variable);
30892906
arma::mat* main_effect_samples = nullptr;
30902907
arma::mat* pairwise_effect_samples = nullptr;
3091-
arma::imat* indicator_samples = nullptr;
2908+
arma::mat* indicator_samples = nullptr;
30922909

30932910
if (save_main) main_effect_samples = new arma::mat(iter, num_main);
30942911
if (save_pairwise) pairwise_effect_samples = new arma::mat(iter, num_pairwise);
3095-
if (save_indicator) indicator_samples = new arma::imat(iter, num_pairwise);
2912+
if (save_indicator) indicator_samples = new arma::mat(iter, num_pairwise);
30962913

30972914
// Initialize proposal SDs and MALA tracking
30982915
arma::mat proposal_sd_main(num_main, 2, arma::fill::ones);
@@ -3175,7 +2992,7 @@ List run_gibbs_sampler_for_bgm (
31752992
}
31762993
arma::vec cached_interaction_gradient; // will hold a cached gradient vector
31772994
bool gradient_valid = false; // indicates whether the cache is valid
3178-
2995+
arma::vec posterior_prob(num_pairwise);
31792996

31802997
// --- Set up total number of iterations (burn-in + sampling)
31812998
bool enable_edge_selection = edge_selection;
@@ -3225,7 +3042,7 @@ List run_gibbs_sampler_for_bgm (
32253042
dual_averaging_main, total_burnin, initial_step_size_main,
32263043
sqrt_inv_fisher_main, step_size_pairwise, dual_averaging_pairwise,
32273044
initial_step_size_pairwise, sqrt_inv_fisher_pairwise, update_method,
3228-
cached_interaction_gradient, gradient_valid
3045+
cached_interaction_gradient, gradient_valid, posterior_prob
32293046
);
32303047

32313048
// --- Update edge probabilities under the prior (if edge selection is active)
@@ -3295,11 +3112,11 @@ List run_gibbs_sampler_for_bgm (
32953112
}
32963113

32973114
if (save_indicator) {
3298-
arma::ivec vectorized_indicator(num_pairwise);
3299-
for (int i = 0; i < num_pairwise; i++) {
3300-
vectorized_indicator(i) = inclusion_indicator(interaction_index_matrix(i, 1), interaction_index_matrix(i, 2));
3301-
}
3302-
indicator_samples->row(sample_index) = vectorized_indicator.t();
3115+
//arma::ivec vectorized_indicator(num_pairwise);
3116+
//for (int i = 0; i < num_pairwise; i++) {
3117+
// vectorized_indicator(i) = inclusion_indicator(interaction_index_matrix(i, 1), interaction_index_matrix(i, 2));
3118+
//}
3119+
indicator_samples->row(sample_index) = posterior_prob.t();//vectorized_indicator.t();
33033120
}
33043121

33053122
if (edge_prior == "Stochastic-Block") {

0 commit comments

Comments
 (0)