Skip to content

Commit e98edc2

Browse files
authored
Minor optimizations (#61)
* use enums instead of strings
* don't forget Bernoulli prior
* small tweaks to the gradient of bgmCompare
* fix edge prior enum and small tweak to progress bar
* try to fix the errors
* use custom exp/log when calling arma::exp/arma::log on Windows
* use logit instead of +log(x)-log(1-x)
* fix numerical difference
1 parent a9d2dd4 commit e98edc2

15 files changed

+235
-132
lines changed

src/bgmCompare_logp_and_grad.cpp

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,12 @@ double log_pseudoposterior(
148148

149149
if (is_ordinal_variable(v)) {
150150
// base term exp(-bound)
151-
denom = arma::exp(-bound);
151+
denom = ARMA_MY_EXP(-bound);
152152
// main_effects from main_group
153153
for (int c = 0; c < num_cats; ++c) {
154154
const double th = main_group(v, c);
155155
const arma::vec exponent = th + (c + 1) * rest_score - bound;
156-
denom += arma::exp(exponent);
156+
denom += ARMA_MY_EXP(exponent);
157157
}
158158
} else {
159159
// linear/quadratic main effects from main_group
@@ -165,11 +165,11 @@ double log_pseudoposterior(
165165
const double quad = quad_effect * centered * centered;
166166
const double lin = lin_effect * c;
167167
const arma::vec exponent = lin + quad + c * rest_score - bound;
168-
denom += arma::exp(exponent);
168+
denom += ARMA_MY_EXP(exponent);
169169
}
170170
}
171171
// - sum_i [ bound_i + log denom_i ]
172-
log_pp -= arma::accu(bound + arma::log(denom));
172+
log_pp -= arma::accu(bound + ARMA_MY_LOG(denom));
173173
}
174174
}
175175

@@ -281,8 +281,8 @@ arma::vec gradient(
281281
const double main_beta,
282282
const double interaction_scale,
283283
const double difference_scale,
284-
const arma::imat main_index,
285-
const arma::imat pair_index
284+
const arma::imat& main_index,
285+
const arma::imat& pair_index
286286
) {
287287
const int num_variables = observations.n_cols;
288288
const int max_num_categories = num_categories.max();
@@ -300,15 +300,21 @@ arma::vec gradient(
300300
total_len += static_cast<long long>(r1 - r0 + 1) * (num_groups - 1);
301301
}
302302
}
303-
for (int v1 = 0; v1 < num_variables - 1; ++v1) {
304-
for (int v2 = v1 + 1; v2 < num_variables; ++v2) {
305-
if (inclusion_indicator(v1, v2) == 1) total_len += (num_groups - 1);
303+
for (int v2 = 0; v2 < num_variables - 1; ++v2) {
304+
for (int v1 = v2 + 1; v1 < num_variables; ++v1) {
305+
total_len += (inclusion_indicator(v1, v2) == 1) * (num_groups - 1);
306306
}
307307
}
308308

309309
arma::vec grad(total_len, arma::fill::zeros);
310310
int off;
311311

312+
// -------------------------------------------------
313+
// Allocate temporaries ONCE (reused inside loops)
314+
// -------------------------------------------------
315+
arma::mat main_group(num_variables, max_num_categories, arma::fill::none);
316+
arma::mat pairwise_group(num_variables, num_variables, arma::fill::none);
317+
312318
// -------------------------------
313319
// Observed sufficient statistics
314320
// -------------------------------
@@ -324,15 +330,14 @@ arma::vec gradient(
324330

325331
if (is_ordinal_variable(v)) {
326332
for (int c = 0; c < num_cats; ++c) {
327-
// overall
328333
off = main_index(base + c, 0);
329334
grad(off) += counts_per_category(c, v);
330335

331-
// diffs
332336
if (inclusion_indicator(v, v) != 0) {
337+
const int cnt = counts_per_category(c, v);
333338
for (int k = 1; k < num_groups; ++k) {
334339
off = main_index(base + c, k);
335-
grad(off) += counts_per_category(c, v) * projection(g, k - 1);
340+
grad(off) += cnt * projection(g, k - 1);
336341
}
337342
}
338343
}
@@ -344,7 +349,6 @@ arma::vec gradient(
344349
off = main_index(base + 1, 0);
345350
grad(off) += blume_capel_stats(1, v);
346351

347-
// diffs
348352
if (inclusion_indicator(v, v) != 0) {
349353
for (int k = 1; k < num_groups; ++k) {
350354
off = main_index(base, k);
@@ -382,9 +386,9 @@ arma::vec gradient(
382386
const int r0 = group_indices(g, 0);
383387
const int r1 = group_indices(g, 1);
384388

385-
arma::mat main_group(num_variables, max_num_categories, arma::fill::zeros);
386-
arma::mat pairwise_group(num_variables, num_variables, arma::fill::zeros);
387389
const arma::vec proj_g = projection.row(g).t(); // length = num_groups-1
390+
main_group.zeros();
391+
pairwise_group.zeros();
388392

389393
// build group-specific params
390394
for (int v = 0; v < num_variables; ++v) {
@@ -393,8 +397,7 @@ arma::vec gradient(
393397
);
394398
main_group(v, arma::span(0, me.n_elem - 1)) = me.t();
395399

396-
for (int u = v; u < num_variables; ++u) { // Combines with loop over v
397-
if(u == v) continue;
400+
for (int u = v + 1; u < num_variables; ++u) {
398401
double w = compute_group_pairwise_effects(
399402
v, u, num_groups, pairwise_effects, pairwise_effect_indices,
400403
inclusion_indicator, proj_g
@@ -415,15 +418,14 @@ arma::vec gradient(
415418

416419
arma::vec rest_score = residual_matrix.col(v);
417420
arma::vec bound = K * rest_score;
418-
bound = arma::clamp(bound, 0.0, arma::datum::inf);
421+
bound.clamp(0.0, arma::datum::inf);
419422

420-
arma::mat exponents(num_group_obs, K + 1, arma::fill::zeros);
423+
arma::mat exponents(num_group_obs, K + 1, arma::fill::none);
421424

422425
if (is_ordinal_variable(v)) {
423-
exponents.col(0) -= bound;
424-
arma::vec main_param = main_group.row(v).cols(0, K - 1).t();
425-
for (int j = 0; j < K; j++) {
426-
exponents.col(j+1) = main_param(j) + (j + 1) * rest_score - bound;
426+
exponents.col(0) = -bound;
427+
for (int j = 0; j < K; ++j) {
428+
exponents.col(j + 1) = main_group(v, j) + (j + 1) * rest_score - bound;
427429
}
428430
} else {
429431
const double lin_effect = main_group(v, 0);
@@ -442,58 +444,72 @@ arma::vec gradient(
442444

443445
// ---- MAIN expected ----
444446
const int base = main_effect_indices(v, 0);
445-
446447
if (is_ordinal_variable(v)) {
447448
for (int s = 1; s <= K; ++s) {
448449
const int j = s - 1;
450+
double sum_col_s = arma::accu(probs.col(s));
451+
452+
449453
off = main_index(base + j, 0);
450-
grad(off) -= arma::accu(probs.col(s));
454+
grad(off) -= sum_col_s;
455+
451456

452457
if (inclusion_indicator(v, v) == 0) continue;
453458
for (int k = 1; k < num_groups; ++k) {
454459
off = main_index(base + j, k);
455-
grad(off) -= projection(g, k - 1) * arma::accu(probs.col(s));
460+
grad(off) -= projection(g, k - 1) * sum_col_s;
461+
456462
}
457463
}
458464
} else {
465+
459466
arma::vec lin_score = arma::regspace<arma::vec>(0, K); // length K+1
460467
arma::vec quad_score = arma::square(lin_score - ref);
461468

462-
off = main_index(base, 0);
463-
grad(off) -= arma::accu(probs * lin_score);
469+
double sum_lin = arma::accu(probs * lin_score);
470+
double sum_quad = arma::accu(probs * quad_score);
464471

472+
off = main_index(base, 0);
473+
grad(off) -= sum_lin;
465474
off = main_index(base + 1, 0);
466-
grad(off) -= arma::accu(probs * quad_score);
475+
grad(off) -= sum_quad;
467476

468477
if (inclusion_indicator(v, v) == 0) continue;
469478
for (int k = 1; k < num_groups; ++k) {
470479
off = main_index(base, k);
471-
grad(off) -= projection(g, k - 1) * arma::accu(probs * lin_score);
480+
grad(off) -= projection(g, k - 1) * sum_lin;
481+
472482
off = main_index(base + 1, k);
473-
grad(off) -= projection(g, k - 1) * arma::accu(probs * quad_score);
483+
grad(off) -= projection(g, k - 1) * sum_quad;
474484
}
475485
}
476486

477487
// ---- PAIRWISE expected ----
478488
for (int v2 = 0; v2 < num_variables; ++v2) {
479489
if (v == v2) continue;
480490

481-
const int row = (v < v2) ? pairwise_effect_indices(v, v2)
482-
: pairwise_effect_indices(v2, v);
483-
484491
arma::vec expected_value(num_group_obs, arma::fill::zeros);
485-
486492
for (int s = 1; s <= K; ++s) {
487493
expected_value += s * probs.col(s) % obs.col(v2);
488494
}
495+
double sum_expectation = arma::accu(expected_value);
496+
497+
// this is mathematically equivalent but numerically different...
498+
// double sum_expectation = 0.0;
499+
// for (int s = 1; s <= K; ++s) {
500+
// sum_expectation += s * arma::dot(probs.col(s), obs.col(v2));
501+
// }
502+
503+
const int row = (v < v2) ? pairwise_effect_indices(v, v2)
504+
: pairwise_effect_indices(v2, v);
489505

490506
off = pair_index(row, 0);
491-
grad(off) -= arma::accu(expected_value);
507+
grad(off) -= sum_expectation;
492508

493509
if (inclusion_indicator(v, v2) == 0) continue;
494510
for (int k = 1; k < num_groups; ++k) {
495511
off = pair_index(row, k);
496-
grad(off) -= projection(g, k - 1) * arma::accu(expected_value);
512+
grad(off) -= projection(g, k - 1) * sum_expectation;
497513
}
498514
}
499515
}
@@ -706,12 +722,12 @@ double log_pseudoposterior_main_component(
706722
arma::vec denom(rest_score.n_elem, arma::fill::zeros);
707723
if (is_ordinal_variable(variable)) {
708724
// base term exp(-bound)
709-
denom = arma::exp(-bound);
725+
denom = ARMA_MY_EXP(-bound);
710726
// main_effects from main_group
711727
for (int cat = 0; cat < num_cats; cat++) {
712728
const double th = main_group(variable, cat);
713729
const arma::vec exponent = th + (cat + 1) * rest_score - bound;
714-
denom += arma::exp(exponent);
730+
denom += ARMA_MY_EXP(exponent);
715731
}
716732
} else {
717733
// linear/quadratic main effects from main_group
@@ -723,11 +739,11 @@ double log_pseudoposterior_main_component(
723739
const double quad = quad_effect * centered * centered;
724740
const double lin = lin_effect * cat;
725741
const arma::vec exponent = lin + quad + cat * rest_score - bound;
726-
denom += arma::exp(exponent);
742+
denom += ARMA_MY_EXP(exponent);
727743
}
728744
}
729745
// - sum_i [ bound_i + log denom_i ]
730-
log_pp -= arma::accu(bound + arma::log(denom));
746+
log_pp -= arma::accu(bound + ARMA_MY_LOG(denom));
731747
}
732748

733749
// ---- priors ----
@@ -890,12 +906,12 @@ double log_pseudoposterior_pair_component(
890906

891907
if (is_ordinal_variable(v)) {
892908
// base term exp(-bound)
893-
denom = arma::exp(-bound);
909+
denom = ARMA_MY_EXP(-bound);
894910
// main_effects from main_group
895911
for (int c = 0; c < num_cats; ++c) {
896912
const double th = main_group(v, c);
897913
const arma::vec exponent = th + (c + 1) * rest_score - bound;
898-
denom += arma::exp(exponent);
914+
denom += ARMA_MY_EXP(exponent);
899915
}
900916
} else {
901917
// linear/quadratic main effects from main_group
@@ -907,11 +923,11 @@ double log_pseudoposterior_pair_component(
907923
const double quad = quad_effect * centered * centered;
908924
const double lin = lin_effect * c;
909925
const arma::vec exponent = lin + quad + c * rest_score - bound;
910-
denom += arma::exp(exponent);
926+
denom += ARMA_MY_EXP(exponent);
911927
}
912928
}
913929
// - sum_i [ bound_i + log denom_i ]
914-
log_pp -= arma::accu(bound + arma::log(denom));
930+
log_pp -= arma::accu(bound + ARMA_MY_LOG(denom));
915931
}
916932
}
917933

@@ -1037,12 +1053,12 @@ double log_ratio_pseudolikelihood_constant_variable(
10371053
bound_current = num_cats * arma::clamp(rest_current, 0.0, arma::datum::inf);
10381054
bound_proposed = num_cats * arma::clamp(rest_proposed, 0.0, arma::datum::inf);
10391055

1040-
denom_current = arma::exp(-bound_current);
1041-
denom_proposed = arma::exp(-bound_proposed);
1056+
denom_current = ARMA_MY_EXP(-bound_current);
1057+
denom_proposed = ARMA_MY_EXP(-bound_proposed);
10421058

10431059
for (int c = 0; c < num_cats; ++c) {
1044-
denom_current += arma::exp(main_current(c) + (c + 1) * rest_current - bound_current);
1045-
denom_proposed += arma::exp(main_proposed(c) + (c + 1) * rest_proposed - bound_proposed);
1060+
denom_current += ARMA_MY_EXP(main_current(c) + (c + 1) * rest_current - bound_current);
1061+
denom_proposed += ARMA_MY_EXP(main_proposed(c) + (c + 1) * rest_proposed - bound_proposed);
10461062
}
10471063
} else {
10481064
// Blume-Capel: linear + quadratic
@@ -1063,14 +1079,14 @@ double log_ratio_pseudolikelihood_constant_variable(
10631079
bound_proposed = lbound + num_cats * arma::clamp(rest_proposed, 0.0, arma::datum::inf);
10641080

10651081
for (int s = 0; s <= num_cats; ++s) {
1066-
denom_current += arma::exp(const_current(s) + s * rest_current - bound_current);
1067-
denom_proposed += arma::exp(const_proposed(s) + s * rest_proposed - bound_proposed);
1082+
denom_current += ARMA_MY_EXP(const_current(s) + s * rest_current - bound_current);
1083+
denom_proposed += ARMA_MY_EXP(const_proposed(s) + s * rest_proposed - bound_proposed);
10681084
}
10691085
}
10701086

10711087
// --- accumulate contribution ---
10721088
log_ratio += arma::accu((bound_current - bound_proposed) +
1073-
arma::log(denom_current) - arma::log(denom_proposed));
1089+
ARMA_MY_LOG(denom_current) - ARMA_MY_LOG(denom_proposed));
10741090
}
10751091

10761092
return log_ratio;

src/bgmCompare_logp_and_grad.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ arma::vec gradient(
4646
const double main_beta,
4747
const double interaction_scale,
4848
const double difference_scale,
49-
const arma::imat main_index,
50-
const arma::imat pair_index
49+
const arma::imat& main_index,
50+
const arma::imat& pair_index
5151
);
5252

5353
double log_pseudoposterior_main_component(

src/bgmCompare_parallel.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "progress_manager.h"
1010
#include "sampler_output.h"
1111
#include "mcmc_adaptation.h"
12+
#include "common_helpers.h"
1213

1314
using namespace Rcpp;
1415
using namespace RcppParallel;
@@ -131,7 +132,7 @@ struct GibbsCompareChainRunner : public Worker {
131132
const arma::mat& inclusion_probability_master;
132133
// RNG seeds
133134
const std::vector<SafeRNG>& chain_rngs;
134-
const std::string& update_method;
135+
const UpdateMethod update_method;
135136
const int hmc_num_leapfrogs;
136137
ProgressManager& pm;
137138
// output
@@ -169,7 +170,7 @@ struct GibbsCompareChainRunner : public Worker {
169170
const arma::imat& interaction_index_matrix,
170171
const arma::mat& inclusion_probability_master,
171172
const std::vector<SafeRNG>& chain_rngs,
172-
const std::string& update_method,
173+
const UpdateMethod update_method,
173174
const int hmc_num_leapfrogs,
174175
ProgressManager& pm,
175176
std::vector<ChainResultCompare>& results
@@ -394,8 +395,10 @@ Rcpp::List run_bgmCompare_parallel(
394395
chain_rngs[c] = SafeRNG(seed + c);
395396
}
396397

398+
UpdateMethod update_method_enum = update_method_from_string(update_method);
399+
397400
// only used to determine the total no. warmup iterations, a bit hacky
398-
WarmupSchedule warmup_schedule_temp(warmup, difference_selection, (update_method != "adaptive-metropolis"));
401+
WarmupSchedule warmup_schedule_temp(warmup, difference_selection, (update_method_enum != adaptive_metropolis));
399402
int total_warmup = warmup_schedule_temp.total_warmup;
400403
ProgressManager pm(num_chains, iter, total_warmup, 50, progress_type);
401404

@@ -408,7 +411,7 @@ Rcpp::List run_bgmCompare_parallel(
408411
baseline_category, difference_selection, main_effect_indices,
409412
pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix,
410413
projection, group_membership, group_indices, interaction_index_matrix,
411-
inclusion_probability, chain_rngs, update_method, hmc_num_leapfrogs,
414+
inclusion_probability, chain_rngs, update_method_enum, hmc_num_leapfrogs,
412415
pm, results
413416
);
414417

0 commit comments

Comments
 (0)