diff --git a/src/bgmCompare_logp_and_grad.cpp b/src/bgmCompare_logp_and_grad.cpp index 32794912..72d0afb5 100644 --- a/src/bgmCompare_logp_and_grad.cpp +++ b/src/bgmCompare_logp_and_grad.cpp @@ -148,12 +148,12 @@ double log_pseudoposterior( if (is_ordinal_variable(v)) { // base term exp(-bound) - denom = arma::exp(-bound); + denom = ARMA_MY_EXP(-bound); // main_effects from main_group for (int c = 0; c < num_cats; ++c) { const double th = main_group(v, c); const arma::vec exponent = th + (c + 1) * rest_score - bound; - denom += arma::exp(exponent); + denom += ARMA_MY_EXP(exponent); } } else { // linear/quadratic main effects from main_group @@ -165,11 +165,11 @@ double log_pseudoposterior( const double quad = quad_effect * centered * centered; const double lin = lin_effect * c; const arma::vec exponent = lin + quad + c * rest_score - bound; - denom += arma::exp(exponent); + denom += ARMA_MY_EXP(exponent); } } // - sum_i [ bound_i + log denom_i ] - log_pp -= arma::accu(bound + arma::log(denom)); + log_pp -= arma::accu(bound + ARMA_MY_LOG(denom)); } } @@ -281,8 +281,8 @@ arma::vec gradient( const double main_beta, const double interaction_scale, const double difference_scale, - const arma::imat main_index, - const arma::imat pair_index + const arma::imat& main_index, + const arma::imat& pair_index ) { const int num_variables = observations.n_cols; const int max_num_categories = num_categories.max(); @@ -300,15 +300,21 @@ arma::vec gradient( total_len += static_cast(r1 - r0 + 1) * (num_groups - 1); } } - for (int v1 = 0; v1 < num_variables - 1; ++v1) { - for (int v2 = v1 + 1; v2 < num_variables; ++v2) { - if (inclusion_indicator(v1, v2) == 1) total_len += (num_groups - 1); + for (int v2 = 0; v2 < num_variables - 1; ++v2) { + for (int v1 = v2 + 1; v1 < num_variables; ++v1) { + total_len += (inclusion_indicator(v1, v2) == 1) * (num_groups - 1); } } arma::vec grad(total_len, arma::fill::zeros); int off; + // ------------------------------------------------- + // Allocate 
temporaries ONCE (reused inside loops) + // ------------------------------------------------- + arma::mat main_group(num_variables, max_num_categories, arma::fill::none); + arma::mat pairwise_group(num_variables, num_variables, arma::fill::none); + // ------------------------------- // Observed sufficient statistics // ------------------------------- @@ -324,15 +330,14 @@ arma::vec gradient( if (is_ordinal_variable(v)) { for (int c = 0; c < num_cats; ++c) { - // overall off = main_index(base + c, 0); grad(off) += counts_per_category(c, v); - // diffs if (inclusion_indicator(v, v) != 0) { + const int cnt = counts_per_category(c, v); for (int k = 1; k < num_groups; ++k) { off = main_index(base + c, k); - grad(off) += counts_per_category(c, v) * projection(g, k - 1); + grad(off) += cnt * projection(g, k - 1); } } } @@ -344,7 +349,6 @@ arma::vec gradient( off = main_index(base + 1, 0); grad(off) += blume_capel_stats(1, v); - // diffs if (inclusion_indicator(v, v) != 0) { for (int k = 1; k < num_groups; ++k) { off = main_index(base, k); @@ -382,9 +386,9 @@ arma::vec gradient( const int r0 = group_indices(g, 0); const int r1 = group_indices(g, 1); - arma::mat main_group(num_variables, max_num_categories, arma::fill::zeros); - arma::mat pairwise_group(num_variables, num_variables, arma::fill::zeros); const arma::vec proj_g = projection.row(g).t(); // length = num_groups-1 + main_group.zeros(); + pairwise_group.zeros(); // build group-specific params for (int v = 0; v < num_variables; ++v) { @@ -393,8 +397,7 @@ arma::vec gradient( ); main_group(v, arma::span(0, me.n_elem - 1)) = me.t(); - for (int u = v; u < num_variables; ++u) { // Combines with loop over v - if(u == v) continue; + for (int u = v + 1; u < num_variables; ++u) { double w = compute_group_pairwise_effects( v, u, num_groups, pairwise_effects, pairwise_effect_indices, inclusion_indicator, proj_g @@ -415,15 +418,14 @@ arma::vec gradient( arma::vec rest_score = residual_matrix.col(v); arma::vec bound = K * 
rest_score; - bound = arma::clamp(bound, 0.0, arma::datum::inf); + bound.clamp(0.0, arma::datum::inf); - arma::mat exponents(num_group_obs, K + 1, arma::fill::zeros); + arma::mat exponents(num_group_obs, K + 1, arma::fill::none); if (is_ordinal_variable(v)) { - exponents.col(0) -= bound; - arma::vec main_param = main_group.row(v).cols(0, K - 1).t(); - for (int j = 0; j < K; j++) { - exponents.col(j+1) = main_param(j) + (j + 1) * rest_score - bound; + exponents.col(0) = -bound; + for (int j = 0; j < K; ++j) { + exponents.col(j + 1) = main_group(v, j) + (j + 1) * rest_score - bound; } } else { const double lin_effect = main_group(v, 0); @@ -442,35 +444,43 @@ arma::vec gradient( // ---- MAIN expected ---- const int base = main_effect_indices(v, 0); - if (is_ordinal_variable(v)) { for (int s = 1; s <= K; ++s) { const int j = s - 1; + double sum_col_s = arma::accu(probs.col(s)); + + off = main_index(base + j, 0); - grad(off) -= arma::accu(probs.col(s)); + grad(off) -= sum_col_s; + if (inclusion_indicator(v, v) == 0) continue; for (int k = 1; k < num_groups; ++k) { off = main_index(base + j, k); - grad(off) -= projection(g, k - 1) * arma::accu(probs.col(s)); + grad(off) -= projection(g, k - 1) * sum_col_s; + } } } else { + arma::vec lin_score = arma::regspace(0, K); // length K+1 arma::vec quad_score = arma::square(lin_score - ref); - off = main_index(base, 0); - grad(off) -= arma::accu(probs * lin_score); + double sum_lin = arma::accu(probs * lin_score); + double sum_quad = arma::accu(probs * quad_score); + off = main_index(base, 0); + grad(off) -= sum_lin; off = main_index(base + 1, 0); - grad(off) -= arma::accu(probs * quad_score); + grad(off) -= sum_quad; if (inclusion_indicator(v, v) == 0) continue; for (int k = 1; k < num_groups; ++k) { off = main_index(base, k); - grad(off) -= projection(g, k - 1) * arma::accu(probs * lin_score); + grad(off) -= projection(g, k - 1) * sum_lin; + off = main_index(base + 1, k); - grad(off) -= projection(g, k - 1) * arma::accu(probs * 
quad_score); + grad(off) -= projection(g, k - 1) * sum_quad; } } @@ -478,22 +488,28 @@ arma::vec gradient( for (int v2 = 0; v2 < num_variables; ++v2) { if (v == v2) continue; - const int row = (v < v2) ? pairwise_effect_indices(v, v2) - : pairwise_effect_indices(v2, v); - arma::vec expected_value(num_group_obs, arma::fill::zeros); - for (int s = 1; s <= K; ++s) { expected_value += s * probs.col(s) % obs.col(v2); } + double sum_expectation = arma::accu(expected_value); + + // this is mathematically equivalent but numerically different... + // double sum_expectation = 0.0; + // for (int s = 1; s <= K; ++s) { + // sum_expectation += s * arma::dot(probs.col(s), obs.col(v2)); + // } + + const int row = (v < v2) ? pairwise_effect_indices(v, v2) + : pairwise_effect_indices(v2, v); off = pair_index(row, 0); - grad(off) -= arma::accu(expected_value); + grad(off) -= sum_expectation; if (inclusion_indicator(v, v2) == 0) continue; for (int k = 1; k < num_groups; ++k) { off = pair_index(row, k); - grad(off) -= projection(g, k - 1) * arma::accu(expected_value); + grad(off) -= projection(g, k - 1) * sum_expectation; } } } @@ -706,12 +722,12 @@ double log_pseudoposterior_main_component( arma::vec denom(rest_score.n_elem, arma::fill::zeros); if (is_ordinal_variable(variable)) { // base term exp(-bound) - denom = arma::exp(-bound); + denom = ARMA_MY_EXP(-bound); // main_effects from main_group for (int cat = 0; cat < num_cats; cat++) { const double th = main_group(variable, cat); const arma::vec exponent = th + (cat + 1) * rest_score - bound; - denom += arma::exp(exponent); + denom += ARMA_MY_EXP(exponent); } } else { // linear/quadratic main effects from main_group @@ -723,11 +739,11 @@ double log_pseudoposterior_main_component( const double quad = quad_effect * centered * centered; const double lin = lin_effect * cat; const arma::vec exponent = lin + quad + cat * rest_score - bound; - denom += arma::exp(exponent); + denom += ARMA_MY_EXP(exponent); } } // - sum_i [ bound_i + log 
denom_i ] - log_pp -= arma::accu(bound + arma::log(denom)); + log_pp -= arma::accu(bound + ARMA_MY_LOG(denom)); } // ---- priors ---- @@ -890,12 +906,12 @@ double log_pseudoposterior_pair_component( if (is_ordinal_variable(v)) { // base term exp(-bound) - denom = arma::exp(-bound); + denom = ARMA_MY_EXP(-bound); // main_effects from main_group for (int c = 0; c < num_cats; ++c) { const double th = main_group(v, c); const arma::vec exponent = th + (c + 1) * rest_score - bound; - denom += arma::exp(exponent); + denom += ARMA_MY_EXP(exponent); } } else { // linear/quadratic main effects from main_group @@ -907,11 +923,11 @@ double log_pseudoposterior_pair_component( const double quad = quad_effect * centered * centered; const double lin = lin_effect * c; const arma::vec exponent = lin + quad + c * rest_score - bound; - denom += arma::exp(exponent); + denom += ARMA_MY_EXP(exponent); } } // - sum_i [ bound_i + log denom_i ] - log_pp -= arma::accu(bound + arma::log(denom)); + log_pp -= arma::accu(bound + ARMA_MY_LOG(denom)); } } @@ -1037,12 +1053,12 @@ double log_ratio_pseudolikelihood_constant_variable( bound_current = num_cats * arma::clamp(rest_current, 0.0, arma::datum::inf); bound_proposed = num_cats * arma::clamp(rest_proposed, 0.0, arma::datum::inf); - denom_current = arma::exp(-bound_current); - denom_proposed = arma::exp(-bound_proposed); + denom_current = ARMA_MY_EXP(-bound_current); + denom_proposed = ARMA_MY_EXP(-bound_proposed); for (int c = 0; c < num_cats; ++c) { - denom_current += arma::exp(main_current(c) + (c + 1) * rest_current - bound_current); - denom_proposed += arma::exp(main_proposed(c) + (c + 1) * rest_proposed - bound_proposed); + denom_current += ARMA_MY_EXP(main_current(c) + (c + 1) * rest_current - bound_current); + denom_proposed += ARMA_MY_EXP(main_proposed(c) + (c + 1) * rest_proposed - bound_proposed); } } else { // Blume-Capel: linear + quadratic @@ -1063,14 +1079,14 @@ double log_ratio_pseudolikelihood_constant_variable( bound_proposed 
= lbound + num_cats * arma::clamp(rest_proposed, 0.0, arma::datum::inf); for (int s = 0; s <= num_cats; ++s) { - denom_current += arma::exp(const_current(s) + s * rest_current - bound_current); - denom_proposed += arma::exp(const_proposed(s) + s * rest_proposed - bound_proposed); + denom_current += ARMA_MY_EXP(const_current(s) + s * rest_current - bound_current); + denom_proposed += ARMA_MY_EXP(const_proposed(s) + s * rest_proposed - bound_proposed); } } // --- accumulate contribution --- log_ratio += arma::accu((bound_current - bound_proposed) + - arma::log(denom_current) - arma::log(denom_proposed)); + ARMA_MY_LOG(denom_current) - ARMA_MY_LOG(denom_proposed)); } return log_ratio; diff --git a/src/bgmCompare_logp_and_grad.h b/src/bgmCompare_logp_and_grad.h index 332e5e74..5d054482 100644 --- a/src/bgmCompare_logp_and_grad.h +++ b/src/bgmCompare_logp_and_grad.h @@ -46,8 +46,8 @@ arma::vec gradient( const double main_beta, const double interaction_scale, const double difference_scale, - const arma::imat main_index, - const arma::imat pair_index + const arma::imat& main_index, + const arma::imat& pair_index ); double log_pseudoposterior_main_component( diff --git a/src/bgmCompare_parallel.cpp b/src/bgmCompare_parallel.cpp index 748910b2..75c1b2af 100644 --- a/src/bgmCompare_parallel.cpp +++ b/src/bgmCompare_parallel.cpp @@ -9,6 +9,7 @@ #include "progress_manager.h" #include "sampler_output.h" #include "mcmc_adaptation.h" +#include "common_helpers.h" using namespace Rcpp; using namespace RcppParallel; @@ -131,7 +132,7 @@ struct GibbsCompareChainRunner : public Worker { const arma::mat& inclusion_probability_master; // RNG seeds const std::vector& chain_rngs; - const std::string& update_method; + const UpdateMethod update_method; const int hmc_num_leapfrogs; ProgressManager& pm; // output @@ -169,7 +170,7 @@ struct GibbsCompareChainRunner : public Worker { const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability_master, const std::vector& 
chain_rngs, - const std::string& update_method, + const UpdateMethod update_method, const int hmc_num_leapfrogs, ProgressManager& pm, std::vector& results @@ -394,8 +395,10 @@ Rcpp::List run_bgmCompare_parallel( chain_rngs[c] = SafeRNG(seed + c); } + UpdateMethod update_method_enum = update_method_from_string(update_method); + // only used to determine the total no. warmup iterations, a bit hacky - WarmupSchedule warmup_schedule_temp(warmup, difference_selection, (update_method != "adaptive-metropolis")); + WarmupSchedule warmup_schedule_temp(warmup, difference_selection, (update_method_enum != adaptive_metropolis)); int total_warmup = warmup_schedule_temp.total_warmup; ProgressManager pm(num_chains, iter, total_warmup, 50, progress_type); @@ -408,7 +411,7 @@ Rcpp::List run_bgmCompare_parallel( baseline_category, difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, - inclusion_probability, chain_rngs, update_method, hmc_num_leapfrogs, + inclusion_probability, chain_rngs, update_method_enum, hmc_num_leapfrogs, pm, results ); diff --git a/src/bgmCompare_sampler.cpp b/src/bgmCompare_sampler.cpp index 55187591..cf7619d7 100644 --- a/src/bgmCompare_sampler.cpp +++ b/src/bgmCompare_sampler.cpp @@ -1166,12 +1166,11 @@ void update_indicator_differences_metropolis_bgmcompare ( // Add prior inclusion probability contribution double inc_prob = inclusion_probability_difference(var, var); + double logit_inc_prob = MY_LOG(inc_prob / (1 - inc_prob)); if(proposed_ind == 1) { - log_accept += MY_LOG(inc_prob); - log_accept -= MY_LOG(1 - inc_prob); + log_accept += logit_inc_prob; } else { - log_accept -= MY_LOG(inc_prob); - log_accept += MY_LOG(1 - inc_prob); + log_accept -= logit_inc_prob; } // Add parameter prior contribution @@ -1244,12 +1243,11 @@ void update_indicator_differences_metropolis_bgmcompare ( // Add prior inclusion probability 
contribution double inc_prob = inclusion_probability_difference(var1, var2); + double logit_inc_prob = MY_LOG(inc_prob / (1 - inc_prob)); if(proposed_ind == 1) { - log_accept += MY_LOG(inc_prob); - log_accept -= MY_LOG(1 - inc_prob); + log_accept += logit_inc_prob; } else { - log_accept -= MY_LOG(inc_prob); - log_accept += MY_LOG(1 - inc_prob); + log_accept -= logit_inc_prob; } // Add parameter prior contribution @@ -1389,7 +1387,7 @@ void gibbs_update_step_bgmcompare ( SafeRNG& rng, arma::mat& inclusion_probability, int hmc_nuts_leapfrogs, - const std::string& update_method, + const UpdateMethod update_method, arma::mat& proposal_sd_main, arma::mat& proposal_sd_pair, const arma::imat& index @@ -1416,7 +1414,7 @@ void gibbs_update_step_bgmcompare ( } // Step 2: Update parameters - if(update_method == "adaptive-metropolis") { + if(update_method == adaptive_metropolis) { update_main_effects_metropolis_bgmcompare ( main_effects, pairwise_effects, main_effect_indices, pairwise_effect_indices, inclusion_indicator, projection, @@ -1434,7 +1432,7 @@ void gibbs_update_step_bgmcompare ( pairwise_scale, difference_scale, iteration, rwm_adapt_pair, rng, proposal_sd_pair ); - } else if (update_method == "hamiltonian-mc") { + } else if (update_method == hamiltonian_mc) { update_hmc_bgmcompare( main_effects, pairwise_effects, main_effect_indices, pairwise_effect_indices, inclusion_indicator, projection, num_categories, @@ -1444,7 +1442,7 @@ void gibbs_update_step_bgmcompare ( main_beta, hmc_nuts_leapfrogs, iteration, hmc_adapt, learn_mass_matrix, schedule.selection_enabled(iteration), rng ); - } else if (update_method == "nuts") { + } else if (update_method == nuts) { SamplerResult result = update_nuts_bgmcompare( main_effects, pairwise_effects, main_effect_indices, pairwise_effect_indices, inclusion_indicator, projection, num_categories, @@ -1577,7 +1575,7 @@ SamplerOutput run_gibbs_sampler_bgmCompare( const arma::imat& interaction_index_matrix, arma::mat inclusion_probability, 
SafeRNG& rng, - const std::string& update_method, + const UpdateMethod update_method, const int hmc_num_leapfrogs, ProgressManager& pm ) { @@ -1618,7 +1616,7 @@ SamplerOutput run_gibbs_sampler_bgmCompare( // --- Optional HMC/NUTS warmup stage double initial_step_size = 1.0; - if (update_method == "hamiltonian-mc" || update_method == "nuts") { + if (update_method == hamiltonian_mc || update_method == nuts) { initial_step_size = find_initial_stepsize_bgmcompare( main_effects, pairwise_effects, main_effect_indices, pairwise_effect_indices, inclusion_indicator, projection, num_categories, diff --git a/src/bgmCompare_sampler.h b/src/bgmCompare_sampler.h index ec4eebc3..781d7f0d 100644 --- a/src/bgmCompare_sampler.h +++ b/src/bgmCompare_sampler.h @@ -1,6 +1,7 @@ #pragma once #include +#include "common_helpers.h" #include struct SamplerOutput; @@ -40,7 +41,7 @@ SamplerOutput run_gibbs_sampler_bgmCompare( const arma::imat& interaction_index_matrix, arma::mat inclusion_probability, SafeRNG& rng, - const std::string& update_method, + const UpdateMethod update_method, const int hmc_num_leapfrogs, ProgressManager& pm ); \ No newline at end of file diff --git a/src/bgm_logp_and_grad.cpp b/src/bgm_logp_and_grad.cpp index a5e06fd8..5f01958c 100644 --- a/src/bgm_logp_and_grad.cpp +++ b/src/bgm_logp_and_grad.cpp @@ -79,17 +79,17 @@ double log_pseudoposterior_main_effects_component ( // - main_effect_param_c is the main_effect parameter for category c (0-based) arma::vec residual_score = residual_matrix.col (variable); // rest scores for all persons arma::vec bound = num_cats * residual_score; // numerical bound vector - arma::vec denom = arma::exp (-bound); // initialize with base term + arma::vec denom = ARMA_MY_EXP (-bound); // initialize with base term arma::vec main_effect_param = main_effects.row (variable).cols (0, num_cats - 1).t (); // main_effect parameters for (int cat = 0; cat < num_cats; cat++) { arma::vec exponent = main_effect_param(cat) + (cat + 1) * residual_score - 
bound; // exponent per person - denom += arma::exp (exponent); // accumulate exp terms + denom += ARMA_MY_EXP (exponent); // accumulate exp terms } // We then compute the total log-likelihood contribution as: // log_posterior -= bound + log (denom), summed over all persons - log_posterior -= arma::accu (bound + arma::log (denom)); // total contribution + log_posterior -= arma::accu (bound + ARMA_MY_LOG (denom)); // total contribution } else { const double value = main_effects(variable, parameter); const double linear_main_effect = main_effects(variable, 0); @@ -116,12 +116,12 @@ double log_pseudoposterior_main_effects_component ( double lin_term = linear_main_effect * cat; // precompute linear term arma::vec exponent = lin_term + quad_term + cat * residual_score - bound; - denom += arma::exp (exponent); // accumulate over categories + denom += ARMA_MY_EXP (exponent); // accumulate over categories } // The final log-likelihood contribution is then: // log_posterior -= bound + log (denom), summed over all persons - log_posterior -= arma::accu (bound + arma::log (denom)); // total contribution + log_posterior -= arma::accu (bound + ARMA_MY_LOG (denom)); // total contribution } return log_posterior; @@ -187,10 +187,10 @@ double log_pseudoposterior_interactions_component ( if (is_ordinal_variable (var)) { // Ordinal variable: denominator includes exp (-bounds) - denominator += arma::exp (-bounds); + denominator += ARMA_MY_EXP (-bounds); for (int category = 0; category < num_categories_var; category++) { arma::vec exponent = main_effects (var, category) + (category + 1) * residual_scores - bounds; - denominator += arma::exp(exponent); + denominator += ARMA_MY_EXP(exponent); } } else { @@ -201,12 +201,12 @@ double log_pseudoposterior_interactions_component ( double lin_term = main_effects (var, 0) * category; double quad_term = main_effects (var, 1) * centered_cat * centered_cat; arma::vec exponent = lin_term + quad_term + category * residual_scores - bounds; - 
denominator += arma::exp (exponent); + denominator += ARMA_MY_EXP (exponent); } } // Subtract log partition function and bounds adjustment - log_pseudo_posterior -= arma::accu (arma::log (denominator)); + log_pseudo_posterior -= arma::accu (ARMA_MY_LOG (denominator)); log_pseudo_posterior -= arma::accu (bounds); } @@ -319,11 +319,11 @@ double log_pseudoposterior ( arma::vec denom; if (is_ordinal_variable(variable)) { - denom = arma::exp (-bound); // initialize with base term + denom = ARMA_MY_EXP (-bound); // initialize with base term arma::vec main_effect_param = main_effects.row (variable).cols (0, num_cats - 1).t (); // main_effect parameters for variable for (int cat = 0; cat < num_cats; cat++) { arma::vec exponent = main_effect_param(cat) + (cat + 1) * residual_score - bound; // exponent per person - denom += arma::exp (exponent); // accumulate exp terms + denom += ARMA_MY_EXP (exponent); // accumulate exp terms } } else { const double lin_effect = main_effects(variable, 0); @@ -336,11 +336,11 @@ double log_pseudoposterior ( double quad = quad_effect * centered * centered; // precompute quadratic term double lin = lin_effect * cat; // precompute linear term arma::vec exponent = lin + quad + cat * residual_score - bound; - denom += arma::exp (exponent); // accumulate over categories + denom += ARMA_MY_EXP (exponent); // accumulate over categories } } - log_pseudoposterior -= arma::accu (bound + arma::log (denom)); // total contribution + log_pseudoposterior -= arma::accu (bound + ARMA_MY_LOG (denom)); // total contribution } return log_pseudoposterior; @@ -463,8 +463,8 @@ arma::vec gradient_log_pseudoposterior ( exponents.col(cat) = main_effect_param(cat) + (cat + 1) * residual_score - bound; } - arma::mat probs = arma::exp (exponents); - arma::vec denom = arma::sum(probs, 1) + arma::exp (-bound); + arma::mat probs = ARMA_MY_EXP (exponents); + arma::vec denom = arma::sum(probs, 1) + ARMA_MY_EXP (-bound); probs.each_col() /= denom; // Expected sufficient 
statistics main effects @@ -499,7 +499,7 @@ arma::vec gradient_log_pseudoposterior ( double quad = quad_effect * centered * centered; exponents.col(cat) = lin + quad + score * residual_score - bound; } - arma::mat probs = arma::exp (exponents); + arma::mat probs = ARMA_MY_EXP (exponents); arma::vec denom = arma::sum(probs, 1); probs.each_col() /= denom; @@ -712,8 +712,8 @@ double compute_log_likelihood_ratio_for_variable ( arma::vec denom_proposed = arma::zeros (num_persons); if (is_ordinal_variable (variable)) { - denom_current += arma::exp(-bounds); - denom_proposed += arma::exp(-bounds); + denom_current += ARMA_MY_EXP(-bounds); + denom_proposed += ARMA_MY_EXP(-bounds); for (int category = 0; category < num_categories_var; category++) { const double main = main_effects(variable, category); @@ -739,13 +739,13 @@ double compute_log_likelihood_ratio_for_variable ( double quad_term = main_effects (variable, 1) * centered * centered; arma::vec exponent = lin_term + quad_term + category * residual_scores - bounds; - denom_current += arma::exp (exponent + category * interaction * current_state); - denom_proposed += arma::exp (exponent + category * interaction * proposed_state); + denom_current += ARMA_MY_EXP (exponent + category * interaction * current_state); + denom_proposed += ARMA_MY_EXP (exponent + category * interaction * proposed_state); } } // Accumulated log-likelihood difference across persons - return arma::accu (arma::log (denom_current) - arma::log (denom_proposed)); + return arma::accu (ARMA_MY_LOG (denom_current) - ARMA_MY_LOG (denom_proposed)); } diff --git a/src/bgm_parallel.cpp b/src/bgm_parallel.cpp index 7172d165..ed4d94bf 100644 --- a/src/bgm_parallel.cpp +++ b/src/bgm_parallel.cpp @@ -8,6 +8,7 @@ #include #include "progress_manager.h" #include "mcmc_adaptation.h" +#include "common_helpers.h" using namespace Rcpp; using namespace RcppParallel; @@ -61,7 +62,7 @@ struct GibbsChainRunner : public Worker { const arma::imat& observations; const 
arma::ivec& num_categories; double pairwise_scale; - const std::string& edge_prior; + const EdgePrior edge_prior; const arma::mat& inclusion_probability; double beta_bernoulli_alpha; double beta_bernoulli_beta; @@ -79,7 +80,7 @@ struct GibbsChainRunner : public Worker { const arma::uvec& is_ordinal_variable; const arma::ivec& baseline_category; bool edge_selection; - const std::string& update_method; + const UpdateMethod update_method; const arma::imat& pairwise_effect_indices; double target_accept; const arma::imat& pairwise_stats; @@ -98,7 +99,7 @@ struct GibbsChainRunner : public Worker { const arma::imat& observations, const arma::ivec& num_categories, double pairwise_scale, - const std::string& edge_prior, + const EdgePrior edge_prior, const arma::mat& inclusion_probability, double beta_bernoulli_alpha, double beta_bernoulli_beta, @@ -116,7 +117,7 @@ struct GibbsChainRunner : public Worker { const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, - const std::string& update_method, + const UpdateMethod update_method, const arma::imat& pairwise_effect_indices, double target_accept, const arma::imat& pairwise_stats, @@ -308,18 +309,20 @@ Rcpp::List run_bgm_parallel( chain_rngs[c] = SafeRNG(seed + c); } + UpdateMethod update_method_enum = update_method_from_string(update_method); + EdgePrior edge_prior_enum = edge_prior_from_string(edge_prior); // only used to determine the total no. 
warmup iterations, a bit hacky - WarmupSchedule warmup_schedule_temp(warmup, edge_selection, (update_method != "adaptive-metropolis")); + WarmupSchedule warmup_schedule_temp(warmup, edge_selection, (update_method_enum != adaptive_metropolis)); int total_warmup = warmup_schedule_temp.total_warmup; ProgressManager pm(num_chains, iter, total_warmup, 50, progress_type); GibbsChainRunner worker( - observations, num_categories, pairwise_scale, edge_prior, + observations, num_categories, pairwise_scale, edge_prior_enum, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, - edge_selection, update_method, pairwise_effect_indices, target_accept, + edge_selection, update_method_enum, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, chain_rngs, pm, results ); diff --git a/src/bgm_sampler.cpp b/src/bgm_sampler.cpp index 1c187ef0..72ff24ba 100644 --- a/src/bgm_sampler.cpp +++ b/src/bgm_sampler.cpp @@ -1015,7 +1015,7 @@ void gibbs_update_step_bgm ( const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, const int iteration, - const std::string& update_method, + const UpdateMethod update_method, const arma::imat& pairwise_effect_indices, arma::imat& pairwise_stats, const int hmc_num_leapfrogs, @@ -1051,7 +1051,7 @@ void gibbs_update_step_bgm ( } // Step 2a: Update interaction weights for active edges - if (update_method == "adaptive-metropolis") { + if (update_method == adaptive_metropolis) { update_pairwise_effects_metropolis_bgm ( pairwise_effects, main_effects, inclusion_indicator, observations, num_categories, proposal_sd_pairwise, adapt_pairwise, pairwise_scale, @@ -1061,7 +1061,7 @@ void gibbs_update_step_bgm ( } // Step 2b: Update main effect (main_effect) parameters - if (update_method == 
"adaptive-metropolis") { + if (update_method == adaptive_metropolis) { update_main_effects_metropolis_bgm ( main_effects, observations, num_categories, counts_per_category, blume_capel_stats, baseline_category, is_ordinal_variable, @@ -1072,7 +1072,7 @@ void gibbs_update_step_bgm ( } // Step 2: Update joint parameters if applicable - if (update_method == "hamiltonian-mc") { + if (update_method == hamiltonian_mc) { update_hmc_bgm( main_effects, pairwise_effects, inclusion_indicator, observations, num_categories, counts_per_category, blume_capel_stats, @@ -1081,7 +1081,7 @@ void gibbs_update_step_bgm ( iteration, adapt, learn_mass_matrix, schedule.selection_enabled(iteration), rng ); - } else if (update_method == "nuts") { + } else if (update_method == nuts) { SamplerResult result = update_nuts_bgm( main_effects, pairwise_effects, inclusion_indicator, observations, num_categories, counts_per_category, blume_capel_stats, @@ -1171,7 +1171,7 @@ Rcpp::List run_gibbs_sampler_bgm( arma::imat observations, const arma::ivec& num_categories, const double pairwise_scale, - const std::string& edge_prior, + const EdgePrior edge_prior, arma::mat inclusion_probability, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, @@ -1189,7 +1189,7 @@ Rcpp::List run_gibbs_sampler_bgm( const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool edge_selection, - const std::string& update_method, + const UpdateMethod update_method, const arma::imat pairwise_effect_indices, const double target_accept, arma::imat pairwise_stats, @@ -1223,7 +1223,7 @@ Rcpp::List run_gibbs_sampler_bgm( if (edge_selection) { indicator_samples.set_size(iter, num_pairwise); } - if (edge_selection && edge_prior == "Stochastic-Block") { + if (edge_selection && edge_prior == Stochastic_Block) { allocation_samples.set_size(iter, num_variables); } @@ -1245,7 +1245,7 @@ Rcpp::List run_gibbs_sampler_bgm( arma::vec log_Vn(1); // --- Initialize SBM prior if applicable - if (edge_prior == 
"Stochastic-Block") { + if (edge_prior == Stochastic_Block) { cluster_allocations[0] = 0; cluster_allocations[1] = 1; for (int i = 2; i < num_variables; i++) { @@ -1273,7 +1273,7 @@ Rcpp::List run_gibbs_sampler_bgm( // --- Optional HMC/NUTS warmup stage double initial_step_size_joint = 1.0; - if (update_method == "hamiltonian-mc" || update_method == "nuts") { + if (update_method == hamiltonian_mc || update_method == nuts) { initial_step_size_joint = find_initial_stepsize_bgm( main_effects, pairwise_effects, inclusion_indicator, observations, num_categories, counts_per_category, blume_capel_stats, @@ -1283,7 +1283,7 @@ Rcpp::List run_gibbs_sampler_bgm( } // --- Warmup scheduling + adaptation controller - WarmupSchedule warmup_schedule(warmup, edge_selection, (update_method != "adaptive-metropolis")); + WarmupSchedule warmup_schedule(warmup, edge_selection, (update_method != adaptive_metropolis)); HMCAdaptationController adapt_joint( num_main + num_pairwise, initial_step_size_joint, target_accept, warmup_schedule, learn_mass_matrix @@ -1339,7 +1339,7 @@ Rcpp::List run_gibbs_sampler_bgm( // --- Update edge probabilities under the prior (if edge selection is active) if (warmup_schedule.selection_enabled(iteration)) { - if (edge_prior == "Beta-Bernoulli") { + if (edge_prior == Beta_Bernoulli) { int num_edges_included = 0; for (int i = 0; i < num_variables - 1; i++) for (int j = i + 1; j < num_variables; j++) @@ -1354,7 +1354,7 @@ Rcpp::List run_gibbs_sampler_bgm( for (int j = i + 1; j < num_variables; j++) inclusion_probability(i, j) = inclusion_probability(j, i) = prob; - } else if (edge_prior == "Stochastic-Block") { + } else if (edge_prior == Stochastic_Block) { cluster_allocations = block_allocations_mfm_sbm( cluster_allocations, num_variables, log_Vn, cluster_prob, arma::conv_to::from(inclusion_indicator), dirichlet_alpha, @@ -1396,7 +1396,7 @@ Rcpp::List run_gibbs_sampler_bgm( } } - if (edge_selection && edge_prior == "Stochastic-Block") { + if (edge_selection && 
/// Update methods that update_method_from_string() can produce.
enum UpdateMethod { adaptive_metropolis, hamiltonian_mc, nuts };

/**
 * Translate an update-method name into its UpdateMethod enumerator.
 *
 * Recognised names (exact, case-sensitive match):
 *   "adaptive-metropolis", "hamiltonian-mc", "nuts".
 *
 * @param update_method  Name of the update method.
 * @return The enumerator corresponding to @p update_method.
 * @throws std::invalid_argument if the name is not one of the above.
 */
inline UpdateMethod update_method_from_string(const std::string& update_method) {
  if (update_method == "adaptive-metropolis")
    return adaptive_metropolis;

  if (update_method == "hamiltonian-mc")
    return hamiltonian_mc;

  if (update_method == "nuts")
    return nuts;

  throw std::invalid_argument("Invalid update_method: " + update_method);
}

/// Edge priors that edge_prior_from_string() can produce.
enum EdgePrior { Stochastic_Block, Beta_Bernoulli, Bernoulli, Not_Applicable };

/**
 * Translate an edge-prior name into its EdgePrior enumerator.
 *
 * Recognised names (exact match):
 *   "Stochastic-Block" (also "stochastic-block"), "Beta-Bernoulli",
 *   "Bernoulli", "Not Applicable".
 *
 * @param edge_prior  Name of the edge prior.
 * @return The enumerator corresponding to @p edge_prior.
 * @throws std::invalid_argument if the name is not one of the above.
 */
inline EdgePrior edge_prior_from_string(const std::string& edge_prior) {
  // The string comparisons this enum replaces tested for "Stochastic-Block",
  // so that spelling must be accepted; the lower-case form is kept so that
  // any caller already passing it keeps working.
  if (edge_prior == "Stochastic-Block" || edge_prior == "stochastic-block")
    return Stochastic_Block;

  if (edge_prior == "Beta-Bernoulli")
    return Beta_Bernoulli;

  if (edge_prior == "Bernoulli")
    return Bernoulli;

  if (edge_prior == "Not Applicable")
    return Not_Applicable;

  throw std::invalid_argument("Invalid edge_prior: " + edge_prior);
}
(arma::uword i = 0; i < N; ++i) + out_mem[i] = __ieee754_exp(in_mem[i]); + + return out; +} + +// elementwise log +template +arma::Mat custom_arma_log(const arma::Base& X) +{ + arma::Mat Xin = X.get_ref(); + arma::Mat out(Xin.n_rows, Xin.n_cols, arma::fill::none); + + double* out_mem = out.memptr(); + const double* in_mem = Xin.memptr(); + const arma::uword N = Xin.n_elem; + + for (arma::uword i = 0; i < N; ++i) + out_mem[i] = __ieee754_log(in_mem[i]); + + return out; +} + + +#endif diff --git a/src/explog_switch.h b/src/explog_switch.h index e4a7e2fd..9c174cc4 100644 --- a/src/explog_switch.h +++ b/src/explog_switch.h @@ -21,10 +21,14 @@ #if USE_CUSTOM_LOG #include "e_exp.h" +#include "e_arma_exp.h" #define MY_EXP __ieee754_exp #define MY_LOG __ieee754_log +#define ARMA_MY_EXP custom_arma_exp +#define ARMA_MY_LOG custom_arma_log + // TODO: add and use these // #define MY_EXPM1 std::expm1 // #define MY_LOG1P std::log1p @@ -35,6 +39,8 @@ #define MY_EXP std::exp #define MY_LOG std::log +#define ARMA_MY_EXP arma::exp +#define ARMA_MY_LOG arma::log #endif diff --git a/src/progress_manager.cpp b/src/progress_manager.cpp index 1bb72a45..7e13471a 100644 --- a/src/progress_manager.cpp +++ b/src/progress_manager.cpp @@ -77,7 +77,7 @@ void ProgressManager::update(size_t chainId) { void ProgressManager::finish() { - if (progress_type == 0) return; // No progress display + if (progress_type == 0 || needsToExit) return; // No progress display or user interrupt // Mark all chains as complete and print one final time for (size_t i = 0; i < nChains; i++) @@ -187,12 +187,6 @@ std::string ProgressManager::formatProgressBar(size_t chainId, size_t current, s return output; } -// std::string ProgressManager::formatTimeInfo(double elapsed, double eta) { -// std::ostringstream builder; -// builder << "Elapsed: " << elapsed << "s | ETA: " << eta << "s"; -// return builder.str(); -// } - std::string ProgressManager::formatTimeInfo(double elapsed, double eta) const { std::ostringstream 
builder; builder << "Elapsed: " << formatDuration(elapsed) << " | ETA: " << formatDuration(eta); @@ -302,9 +296,11 @@ void ProgressManager::print() { // totalChars += chainProgress.length() + 1; // +1 for newline } - // Print total progress - std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); - out << totalProgress << "\n"; + // Print total progress if there is more than one chain + if (nChains > 1) { + std::string totalProgress = formatProgressBar(0, done, totalWork, fracTotal, true); + out << totalProgress << "\n"; + } // totalChars += totalProgress.length() + 1; // +1 for newline // Print time info @@ -314,7 +310,9 @@ void ProgressManager::print() { // totalChars += timeInfo.length() + 1; // +1 for newline // Track total lines printed (chains + total + time) - lastPrintedLines = nChains + 2; // used in a generic terminal + + lastPrintedLines = nChains + (nChains > 1 ? 2 : 1); // used in a generic terminal + lastPrintedChars = 1;//totalChars; // used by RStudio } else if (progress_type == 1) {