Skip to content

Commit 4218cc0

Browse files
committed
small tweaks to the gradient of bgmCompare
1 parent 7924be1 commit 4218cc0

File tree

2 files changed

+67
-49
lines changed

2 files changed

+67
-49
lines changed

src/bgmCompare_logp_and_grad.cpp

Lines changed: 65 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ arma::vec gradient(
281281
const double main_beta,
282282
const double interaction_scale,
283283
const double difference_scale,
284-
const arma::imat main_index,
285-
const arma::imat pair_index
284+
const arma::imat& main_index,
285+
const arma::imat& pair_index
286286
) {
287287
const int num_variables = observations.n_cols;
288288
const int max_num_categories = num_categories.max();
@@ -300,22 +300,30 @@ arma::vec gradient(
300300
total_len += static_cast<long long>(r1 - r0 + 1) * (num_groups - 1);
301301
}
302302
}
303-
for (int v1 = 0; v1 < num_variables - 1; ++v1) {
304-
for (int v2 = v1 + 1; v2 < num_variables; ++v2) {
305-
if (inclusion_indicator(v1, v2) == 1) total_len += (num_groups - 1);
303+
for (int v2 = 0; v2 < num_variables - 1; ++v2) {
304+
for (int v1 = v2 + 1; v1 < num_variables; ++v1) {
305+
total_len += (inclusion_indicator(v1, v2) == 1) * (num_groups - 1);
306306
}
307307
}
308308

309309
arma::vec grad(total_len, arma::fill::zeros);
310310
int off;
311311

312+
// -------------------------------------------------
313+
// Allocate temporaries ONCE (reused inside loops)
314+
// -------------------------------------------------
315+
arma::mat main_group(num_variables, max_num_categories, arma::fill::zeros);
316+
arma::mat pairwise_group(num_variables, num_variables, arma::fill::zeros);
317+
arma::mat exponents; // resized per variable
318+
arma::vec lin_score, quad_score;
319+
312320
// -------------------------------
313321
// Observed sufficient statistics
314322
// -------------------------------
315323
for (int g = 0; g < num_groups; ++g) {
316-
// list access
317-
arma::imat counts_per_category = counts_per_category_group[g];
318-
arma::imat blume_capel_stats = blume_capel_stats_group[g];
324+
const arma::imat& counts_per_category = counts_per_category_group[g];
325+
const arma::imat& blume_capel_stats = blume_capel_stats_group[g];
326+
const arma::mat& pairwise_stats = pairwise_stats_group[g];
319327

320328
// Main effects
321329
for (int v = 0; v < num_variables; ++v) {
@@ -324,15 +332,14 @@ arma::vec gradient(
324332

325333
if (is_ordinal_variable(v)) {
326334
for (int c = 0; c < num_cats; ++c) {
327-
// overall
328335
off = main_index(base + c, 0);
329336
grad(off) += counts_per_category(c, v);
330337

331-
// diffs
332338
if (inclusion_indicator(v, v) != 0) {
339+
const int cnt = counts_per_category(c, v);
333340
for (int k = 1; k < num_groups; ++k) {
334341
off = main_index(base + c, k);
335-
grad(off) += counts_per_category(c, v) * projection(g, k - 1);
342+
grad(off) += cnt * projection(g, k - 1);
336343
}
337344
}
338345
}
@@ -344,7 +351,6 @@ arma::vec gradient(
344351
off = main_index(base + 1, 0);
345352
grad(off) += blume_capel_stats(1, v);
346353

347-
// diffs
348354
if (inclusion_indicator(v, v) != 0) {
349355
for (int k = 1; k < num_groups; ++k) {
350356
off = main_index(base, k);
@@ -358,7 +364,6 @@ arma::vec gradient(
358364
}
359365

360366
// Pairwise (observed)
361-
arma::mat pairwise_stats = pairwise_stats_group[g];
362367
for (int v1 = 0; v1 < num_variables - 1; ++v1) {
363368
for (int v2 = v1 + 1; v2 < num_variables; ++v2) {
364369
const int row = pairwise_effect_indices(v1, v2);
@@ -382,9 +387,9 @@ arma::vec gradient(
382387
const int r0 = group_indices(g, 0);
383388
const int r1 = group_indices(g, 1);
384389

385-
arma::mat main_group(num_variables, max_num_categories, arma::fill::zeros);
386-
arma::mat pairwise_group(num_variables, num_variables, arma::fill::zeros);
387-
const arma::vec proj_g = projection.row(g).t(); // length = num_groups-1
390+
main_group.zeros();
391+
pairwise_group.zeros();
392+
arma::vec proj_g = projection.row(g).t(); // length = num_groups-1
388393

389394
// build group-specific params
390395
for (int v = 0; v < num_variables; ++v) {
@@ -393,8 +398,7 @@ arma::vec gradient(
393398
);
394399
main_group(v, arma::span(0, me.n_elem - 1)) = me.t();
395400

396-
for (int u = v; u < num_variables; ++u) { // Combines with loop over v
397-
if(u == v) continue;
401+
for (int u = v + 1; u < num_variables; ++u) {
398402
double w = compute_group_pairwise_effects(
399403
v, u, num_groups, pairwise_effects, pairwise_effect_indices,
400404
inclusion_indicator, proj_g
@@ -415,15 +419,15 @@ arma::vec gradient(
415419

416420
arma::vec rest_score = residual_matrix.col(v);
417421
arma::vec bound = K * rest_score;
418-
bound = arma::clamp(bound, 0.0, arma::datum::inf);
422+
bound.clamp(0.0, arma::datum::inf);
419423

420-
arma::mat exponents(num_group_obs, K + 1, arma::fill::zeros);
424+
exponents.set_size(num_group_obs, K + 1);
425+
exponents.zeros();
421426

422427
if (is_ordinal_variable(v)) {
423428
exponents.col(0) -= bound;
424-
arma::vec main_param = main_group.row(v).cols(0, K - 1).t();
425-
for (int j = 0; j < K; j++) {
426-
exponents.col(j+1) = main_param(j) + (j + 1) * rest_score - bound;
429+
for (int j = 0; j < K; ++j) {
430+
exponents.col(j + 1) = main_group(v, j) + (j + 1) * rest_score - bound;
427431
}
428432
} else {
429433
const double lin_effect = main_group(v, 0);
@@ -436,41 +440,56 @@ arma::vec gradient(
436440
}
437441
}
438442

439-
arma::mat probs = arma::exp(exponents);
440-
arma::vec denom = arma::sum(probs, 1); // base term
441-
probs.each_col() /= denom;
443+
// exponentiate + normalize row-wise (in-place)
444+
for (arma::uword i = 0; i < exponents.n_rows; ++i) {
445+
double denom = 0.0;
446+
for (arma::uword j = 0; j < exponents.n_cols; ++j) {
447+
double val = MY_EXP(exponents(i, j));
448+
exponents(i, j) = val;
449+
denom += val;
450+
}
451+
double inv = 1.0 / denom;
452+
for (arma::uword j = 0; j < exponents.n_cols; ++j)
453+
exponents(i, j) *= inv;
454+
}
442455

443456
// ---- MAIN expected ----
444457
const int base = main_effect_indices(v, 0);
445-
446458
if (is_ordinal_variable(v)) {
447459
for (int s = 1; s <= K; ++s) {
448460
const int j = s - 1;
461+
double sum_col_s = arma::accu(exponents.col(s));
462+
449463
off = main_index(base + j, 0);
450-
grad(off) -= arma::accu(probs.col(s));
464+
grad(off) -= sum_col_s;
451465

452466
if (inclusion_indicator(v, v) == 0) continue;
453467
for (int k = 1; k < num_groups; ++k) {
454468
off = main_index(base + j, k);
455-
grad(off) -= projection(g, k - 1) * arma::accu(probs.col(s));
469+
grad(off) -= projection(g, k - 1) * sum_col_s;
456470
}
457471
}
458472
} else {
459-
arma::vec lin_score = arma::regspace<arma::vec>(0, K); // length K+1
460-
arma::vec quad_score = arma::square(lin_score - ref);
473+
if (lin_score.n_elem != K + 1) {
474+
lin_score = arma::regspace<arma::vec>(0, K);
475+
quad_score = arma::square(lin_score - ref);
476+
}
461477

462-
off = main_index(base, 0);
463-
grad(off) -= arma::accu(probs * lin_score);
478+
double sum_lin = arma::accu(exponents * lin_score);
479+
double sum_quad = arma::accu(exponents * quad_score);
464480

481+
off = main_index(base, 0);
482+
grad(off) -= sum_lin;
465483
off = main_index(base + 1, 0);
466-
grad(off) -= arma::accu(probs * quad_score);
484+
grad(off) -= sum_quad;
467485

468486
if (inclusion_indicator(v, v) == 0) continue;
469487
for (int k = 1; k < num_groups; ++k) {
470488
off = main_index(base, k);
471-
grad(off) -= projection(g, k - 1) * arma::accu(probs * lin_score);
489+
grad(off) -= projection(g, k - 1) * sum_lin;
490+
472491
off = main_index(base + 1, k);
473-
grad(off) -= projection(g, k - 1) * arma::accu(probs * quad_score);
492+
grad(off) -= projection(g, k - 1) * sum_quad;
474493
}
475494
}
476495

@@ -479,21 +498,20 @@ arma::vec gradient(
479498
if (v == v2) continue;
480499

481500
const int row = (v < v2) ? pairwise_effect_indices(v, v2)
482-
: pairwise_effect_indices(v2, v);
501+
: pairwise_effect_indices(v2, v);
483502

484-
arma::vec expected_value(num_group_obs, arma::fill::zeros);
485-
486-
for (int s = 1; s <= K; ++s) {
487-
expected_value += s * probs.col(s) % obs.col(v2);
488-
}
503+
double total = 0.0;
504+
for (int s = 1; s <= K; ++s)
505+
total += s * arma::dot(exponents.col(s), obs.col(v2));
489506

490507
off = pair_index(row, 0);
491-
grad(off) -= arma::accu(expected_value);
508+
grad(off) -= total;
492509

493-
if (inclusion_indicator(v, v2) == 0) continue;
494-
for (int k = 1; k < num_groups; ++k) {
495-
off = pair_index(row, k);
496-
grad(off) -= projection(g, k - 1) * arma::accu(expected_value);
510+
if (inclusion_indicator(v, v2)) {
511+
for (int k = 1; k < num_groups; ++k) {
512+
off = pair_index(row, k);
513+
grad(off) -= projection(g, k - 1) * total;
514+
}
497515
}
498516
}
499517
}

src/bgmCompare_logp_and_grad.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ arma::vec gradient(
4646
const double main_beta,
4747
const double interaction_scale,
4848
const double difference_scale,
49-
const arma::imat main_index,
50-
const arma::imat pair_index
49+
const arma::imat& main_index,
50+
const arma::imat& pair_index
5151
);
5252

5353
double log_pseudoposterior_main_component(

0 commit comments

Comments
 (0)