Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 68 additions & 52 deletions src/bgmCompare_logp_and_grad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ double log_pseudoposterior(

if (is_ordinal_variable(v)) {
// base term exp(-bound)
denom = arma::exp(-bound);
denom = ARMA_MY_EXP(-bound);
// main_effects from main_group
for (int c = 0; c < num_cats; ++c) {
const double th = main_group(v, c);
const arma::vec exponent = th + (c + 1) * rest_score - bound;
denom += arma::exp(exponent);
denom += ARMA_MY_EXP(exponent);
}
} else {
// linear/quadratic main effects from main_group
Expand All @@ -165,11 +165,11 @@ double log_pseudoposterior(
const double quad = quad_effect * centered * centered;
const double lin = lin_effect * c;
const arma::vec exponent = lin + quad + c * rest_score - bound;
denom += arma::exp(exponent);
denom += ARMA_MY_EXP(exponent);
}
}
// - sum_i [ bound_i + log denom_i ]
log_pp -= arma::accu(bound + arma::log(denom));
log_pp -= arma::accu(bound + ARMA_MY_LOG(denom));
}
}

Expand Down Expand Up @@ -281,8 +281,8 @@ arma::vec gradient(
const double main_beta,
const double interaction_scale,
const double difference_scale,
const arma::imat main_index,
const arma::imat pair_index
const arma::imat& main_index,
const arma::imat& pair_index
) {
const int num_variables = observations.n_cols;
const int max_num_categories = num_categories.max();
Expand All @@ -300,15 +300,21 @@ arma::vec gradient(
total_len += static_cast<long long>(r1 - r0 + 1) * (num_groups - 1);
}
}
for (int v1 = 0; v1 < num_variables - 1; ++v1) {
for (int v2 = v1 + 1; v2 < num_variables; ++v2) {
if (inclusion_indicator(v1, v2) == 1) total_len += (num_groups - 1);
for (int v2 = 0; v2 < num_variables - 1; ++v2) {
for (int v1 = v2 + 1; v1 < num_variables; ++v1) {
total_len += (inclusion_indicator(v1, v2) == 1) * (num_groups - 1);
}
}

arma::vec grad(total_len, arma::fill::zeros);
int off;

// -------------------------------------------------
// Allocate temporaries ONCE (reused inside loops)
// -------------------------------------------------
arma::mat main_group(num_variables, max_num_categories, arma::fill::none);
arma::mat pairwise_group(num_variables, num_variables, arma::fill::none);

// -------------------------------
// Observed sufficient statistics
// -------------------------------
Expand All @@ -324,15 +330,14 @@ arma::vec gradient(

if (is_ordinal_variable(v)) {
for (int c = 0; c < num_cats; ++c) {
// overall
off = main_index(base + c, 0);
grad(off) += counts_per_category(c, v);

// diffs
if (inclusion_indicator(v, v) != 0) {
const int cnt = counts_per_category(c, v);
for (int k = 1; k < num_groups; ++k) {
off = main_index(base + c, k);
grad(off) += counts_per_category(c, v) * projection(g, k - 1);
grad(off) += cnt * projection(g, k - 1);
}
}
}
Expand All @@ -344,7 +349,6 @@ arma::vec gradient(
off = main_index(base + 1, 0);
grad(off) += blume_capel_stats(1, v);

// diffs
if (inclusion_indicator(v, v) != 0) {
for (int k = 1; k < num_groups; ++k) {
off = main_index(base, k);
Expand Down Expand Up @@ -382,9 +386,9 @@ arma::vec gradient(
const int r0 = group_indices(g, 0);
const int r1 = group_indices(g, 1);

arma::mat main_group(num_variables, max_num_categories, arma::fill::zeros);
arma::mat pairwise_group(num_variables, num_variables, arma::fill::zeros);
const arma::vec proj_g = projection.row(g).t(); // length = num_groups-1
main_group.zeros();
pairwise_group.zeros();

// build group-specific params
for (int v = 0; v < num_variables; ++v) {
Expand All @@ -393,8 +397,7 @@ arma::vec gradient(
);
main_group(v, arma::span(0, me.n_elem - 1)) = me.t();

for (int u = v; u < num_variables; ++u) { // Combines with loop over v
if(u == v) continue;
for (int u = v + 1; u < num_variables; ++u) {
double w = compute_group_pairwise_effects(
v, u, num_groups, pairwise_effects, pairwise_effect_indices,
inclusion_indicator, proj_g
Expand All @@ -415,15 +418,14 @@ arma::vec gradient(

arma::vec rest_score = residual_matrix.col(v);
arma::vec bound = K * rest_score;
bound = arma::clamp(bound, 0.0, arma::datum::inf);
bound.clamp(0.0, arma::datum::inf);

arma::mat exponents(num_group_obs, K + 1, arma::fill::zeros);
arma::mat exponents(num_group_obs, K + 1, arma::fill::none);

if (is_ordinal_variable(v)) {
exponents.col(0) -= bound;
arma::vec main_param = main_group.row(v).cols(0, K - 1).t();
for (int j = 0; j < K; j++) {
exponents.col(j+1) = main_param(j) + (j + 1) * rest_score - bound;
exponents.col(0) = -bound;
for (int j = 0; j < K; ++j) {
exponents.col(j + 1) = main_group(v, j) + (j + 1) * rest_score - bound;
}
} else {
const double lin_effect = main_group(v, 0);
Expand All @@ -442,58 +444,72 @@ arma::vec gradient(

// ---- MAIN expected ----
const int base = main_effect_indices(v, 0);

if (is_ordinal_variable(v)) {
for (int s = 1; s <= K; ++s) {
const int j = s - 1;
double sum_col_s = arma::accu(probs.col(s));


off = main_index(base + j, 0);
grad(off) -= arma::accu(probs.col(s));
grad(off) -= sum_col_s;


if (inclusion_indicator(v, v) == 0) continue;
for (int k = 1; k < num_groups; ++k) {
off = main_index(base + j, k);
grad(off) -= projection(g, k - 1) * arma::accu(probs.col(s));
grad(off) -= projection(g, k - 1) * sum_col_s;

}
}
} else {

arma::vec lin_score = arma::regspace<arma::vec>(0, K); // length K+1
arma::vec quad_score = arma::square(lin_score - ref);

off = main_index(base, 0);
grad(off) -= arma::accu(probs * lin_score);
double sum_lin = arma::accu(probs * lin_score);
double sum_quad = arma::accu(probs * quad_score);

off = main_index(base, 0);
grad(off) -= sum_lin;
off = main_index(base + 1, 0);
grad(off) -= arma::accu(probs * quad_score);
grad(off) -= sum_quad;

if (inclusion_indicator(v, v) == 0) continue;
for (int k = 1; k < num_groups; ++k) {
off = main_index(base, k);
grad(off) -= projection(g, k - 1) * arma::accu(probs * lin_score);
grad(off) -= projection(g, k - 1) * sum_lin;

off = main_index(base + 1, k);
grad(off) -= projection(g, k - 1) * arma::accu(probs * quad_score);
grad(off) -= projection(g, k - 1) * sum_quad;
}
}

// ---- PAIRWISE expected ----
for (int v2 = 0; v2 < num_variables; ++v2) {
if (v == v2) continue;

const int row = (v < v2) ? pairwise_effect_indices(v, v2)
: pairwise_effect_indices(v2, v);

arma::vec expected_value(num_group_obs, arma::fill::zeros);

for (int s = 1; s <= K; ++s) {
expected_value += s * probs.col(s) % obs.col(v2);
}
double sum_expectation = arma::accu(expected_value);

// this is mathematically equivalent but numerically different...
// double sum_expectation = 0.0;
// for (int s = 1; s <= K; ++s) {
// sum_expectation += s * arma::dot(probs.col(s), obs.col(v2));
// }

const int row = (v < v2) ? pairwise_effect_indices(v, v2)
: pairwise_effect_indices(v2, v);

off = pair_index(row, 0);
grad(off) -= arma::accu(expected_value);
grad(off) -= sum_expectation;

if (inclusion_indicator(v, v2) == 0) continue;
for (int k = 1; k < num_groups; ++k) {
off = pair_index(row, k);
grad(off) -= projection(g, k - 1) * arma::accu(expected_value);
grad(off) -= projection(g, k - 1) * sum_expectation;
}
}
}
Expand Down Expand Up @@ -706,12 +722,12 @@ double log_pseudoposterior_main_component(
arma::vec denom(rest_score.n_elem, arma::fill::zeros);
if (is_ordinal_variable(variable)) {
// base term exp(-bound)
denom = arma::exp(-bound);
denom = ARMA_MY_EXP(-bound);
// main_effects from main_group
for (int cat = 0; cat < num_cats; cat++) {
const double th = main_group(variable, cat);
const arma::vec exponent = th + (cat + 1) * rest_score - bound;
denom += arma::exp(exponent);
denom += ARMA_MY_EXP(exponent);
}
} else {
// linear/quadratic main effects from main_group
Expand All @@ -723,11 +739,11 @@ double log_pseudoposterior_main_component(
const double quad = quad_effect * centered * centered;
const double lin = lin_effect * cat;
const arma::vec exponent = lin + quad + cat * rest_score - bound;
denom += arma::exp(exponent);
denom += ARMA_MY_EXP(exponent);
}
}
// - sum_i [ bound_i + log denom_i ]
log_pp -= arma::accu(bound + arma::log(denom));
log_pp -= arma::accu(bound + ARMA_MY_LOG(denom));
}

// ---- priors ----
Expand Down Expand Up @@ -890,12 +906,12 @@ double log_pseudoposterior_pair_component(

if (is_ordinal_variable(v)) {
// base term exp(-bound)
denom = arma::exp(-bound);
denom = ARMA_MY_EXP(-bound);
// main_effects from main_group
for (int c = 0; c < num_cats; ++c) {
const double th = main_group(v, c);
const arma::vec exponent = th + (c + 1) * rest_score - bound;
denom += arma::exp(exponent);
denom += ARMA_MY_EXP(exponent);
}
} else {
// linear/quadratic main effects from main_group
Expand All @@ -907,11 +923,11 @@ double log_pseudoposterior_pair_component(
const double quad = quad_effect * centered * centered;
const double lin = lin_effect * c;
const arma::vec exponent = lin + quad + c * rest_score - bound;
denom += arma::exp(exponent);
denom += ARMA_MY_EXP(exponent);
}
}
// - sum_i [ bound_i + log denom_i ]
log_pp -= arma::accu(bound + arma::log(denom));
log_pp -= arma::accu(bound + ARMA_MY_LOG(denom));
}
}

Expand Down Expand Up @@ -1037,12 +1053,12 @@ double log_ratio_pseudolikelihood_constant_variable(
bound_current = num_cats * arma::clamp(rest_current, 0.0, arma::datum::inf);
bound_proposed = num_cats * arma::clamp(rest_proposed, 0.0, arma::datum::inf);

denom_current = arma::exp(-bound_current);
denom_proposed = arma::exp(-bound_proposed);
denom_current = ARMA_MY_EXP(-bound_current);
denom_proposed = ARMA_MY_EXP(-bound_proposed);

for (int c = 0; c < num_cats; ++c) {
denom_current += arma::exp(main_current(c) + (c + 1) * rest_current - bound_current);
denom_proposed += arma::exp(main_proposed(c) + (c + 1) * rest_proposed - bound_proposed);
denom_current += ARMA_MY_EXP(main_current(c) + (c + 1) * rest_current - bound_current);
denom_proposed += ARMA_MY_EXP(main_proposed(c) + (c + 1) * rest_proposed - bound_proposed);
}
} else {
// Blume-Capel: linear + quadratic
Expand All @@ -1063,14 +1079,14 @@ double log_ratio_pseudolikelihood_constant_variable(
bound_proposed = lbound + num_cats * arma::clamp(rest_proposed, 0.0, arma::datum::inf);

for (int s = 0; s <= num_cats; ++s) {
denom_current += arma::exp(const_current(s) + s * rest_current - bound_current);
denom_proposed += arma::exp(const_proposed(s) + s * rest_proposed - bound_proposed);
denom_current += ARMA_MY_EXP(const_current(s) + s * rest_current - bound_current);
denom_proposed += ARMA_MY_EXP(const_proposed(s) + s * rest_proposed - bound_proposed);
}
}

// --- accumulate contribution ---
log_ratio += arma::accu((bound_current - bound_proposed) +
arma::log(denom_current) - arma::log(denom_proposed));
ARMA_MY_LOG(denom_current) - ARMA_MY_LOG(denom_proposed));
}

return log_ratio;
Expand Down
4 changes: 2 additions & 2 deletions src/bgmCompare_logp_and_grad.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ arma::vec gradient(
const double main_beta,
const double interaction_scale,
const double difference_scale,
const arma::imat main_index,
const arma::imat pair_index
const arma::imat& main_index,
const arma::imat& pair_index
);

double log_pseudoposterior_main_component(
Expand Down
11 changes: 7 additions & 4 deletions src/bgmCompare_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "progress_manager.h"
#include "sampler_output.h"
#include "mcmc_adaptation.h"
#include "common_helpers.h"

using namespace Rcpp;
using namespace RcppParallel;
Expand Down Expand Up @@ -131,7 +132,7 @@ struct GibbsCompareChainRunner : public Worker {
const arma::mat& inclusion_probability_master;
// RNG seeds
const std::vector<SafeRNG>& chain_rngs;
const std::string& update_method;
const UpdateMethod update_method;
const int hmc_num_leapfrogs;
ProgressManager& pm;
// output
Expand Down Expand Up @@ -169,7 +170,7 @@ struct GibbsCompareChainRunner : public Worker {
const arma::imat& interaction_index_matrix,
const arma::mat& inclusion_probability_master,
const std::vector<SafeRNG>& chain_rngs,
const std::string& update_method,
const UpdateMethod update_method,
const int hmc_num_leapfrogs,
ProgressManager& pm,
std::vector<ChainResultCompare>& results
Expand Down Expand Up @@ -394,8 +395,10 @@ Rcpp::List run_bgmCompare_parallel(
chain_rngs[c] = SafeRNG(seed + c);
}

UpdateMethod update_method_enum = update_method_from_string(update_method);

// only used to determine the total no. warmup iterations, a bit hacky
WarmupSchedule warmup_schedule_temp(warmup, difference_selection, (update_method != "adaptive-metropolis"));
WarmupSchedule warmup_schedule_temp(warmup, difference_selection, (update_method_enum != adaptive_metropolis));
int total_warmup = warmup_schedule_temp.total_warmup;
ProgressManager pm(num_chains, iter, total_warmup, 50, progress_type);

Expand All @@ -408,7 +411,7 @@ Rcpp::List run_bgmCompare_parallel(
baseline_category, difference_selection, main_effect_indices,
pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix,
projection, group_membership, group_indices, interaction_index_matrix,
inclusion_probability, chain_rngs, update_method, hmc_num_leapfrogs,
inclusion_probability, chain_rngs, update_method_enum, hmc_num_leapfrogs,
pm, results
);

Expand Down
Loading
Loading