@@ -148,12 +148,12 @@ double log_pseudoposterior(
148148
149149 if (is_ordinal_variable (v)) {
150150 // base term exp(-bound)
151- denom = arma::exp (-bound);
151+ denom = ARMA_MY_EXP (-bound);
152152 // main_effects from main_group
153153 for (int c = 0 ; c < num_cats; ++c) {
154154 const double th = main_group (v, c);
155155 const arma::vec exponent = th + (c + 1 ) * rest_score - bound;
156- denom += arma::exp (exponent);
156+ denom += ARMA_MY_EXP (exponent);
157157 }
158158 } else {
159159 // linear/quadratic main effects from main_group
@@ -165,11 +165,11 @@ double log_pseudoposterior(
165165 const double quad = quad_effect * centered * centered;
166166 const double lin = lin_effect * c;
167167 const arma::vec exponent = lin + quad + c * rest_score - bound;
168- denom += arma::exp (exponent);
168+ denom += ARMA_MY_EXP (exponent);
169169 }
170170 }
171171 // - sum_i [ bound_i + log denom_i ]
172- log_pp -= arma::accu (bound + arma::log (denom));
172+ log_pp -= arma::accu (bound + ARMA_MY_LOG (denom));
173173 }
174174 }
175175
@@ -281,8 +281,8 @@ arma::vec gradient(
281281 const double main_beta,
282282 const double interaction_scale,
283283 const double difference_scale,
284- const arma::imat main_index,
285- const arma::imat pair_index
284+ const arma::imat& main_index,
285+ const arma::imat& pair_index
286286) {
287287 const int num_variables = observations.n_cols ;
288288 const int max_num_categories = num_categories.max ();
@@ -300,15 +300,21 @@ arma::vec gradient(
300300 total_len += static_cast <long long >(r1 - r0 + 1 ) * (num_groups - 1 );
301301 }
302302 }
303- for (int v1 = 0 ; v1 < num_variables - 1 ; ++v1 ) {
304- for (int v2 = v1 + 1 ; v2 < num_variables; ++v2 ) {
305- if (inclusion_indicator (v1, v2) == 1 ) total_len += (num_groups - 1 );
303+ for (int v2 = 0 ; v2 < num_variables - 1 ; ++v2 ) {
304+ for (int v1 = v2 + 1 ; v1 < num_variables; ++v1 ) {
305+ total_len += (inclusion_indicator (v1, v2) == 1 ) * (num_groups - 1 );
306306 }
307307 }
308308
309309 arma::vec grad (total_len, arma::fill::zeros);
310310 int off;
311311
312+ // -------------------------------------------------
313+ // Allocate temporaries ONCE (reused inside loops)
314+ // -------------------------------------------------
315+ arma::mat main_group (num_variables, max_num_categories, arma::fill::none);
316+ arma::mat pairwise_group (num_variables, num_variables, arma::fill::none);
317+
312318 // -------------------------------
313319 // Observed sufficient statistics
314320 // -------------------------------
@@ -324,15 +330,14 @@ arma::vec gradient(
324330
325331 if (is_ordinal_variable (v)) {
326332 for (int c = 0 ; c < num_cats; ++c) {
327- // overall
328333 off = main_index (base + c, 0 );
329334 grad (off) += counts_per_category (c, v);
330335
331- // diffs
332336 if (inclusion_indicator (v, v) != 0 ) {
337+ const int cnt = counts_per_category (c, v);
333338 for (int k = 1 ; k < num_groups; ++k) {
334339 off = main_index (base + c, k);
335- grad (off) += counts_per_category (c, v) * projection (g, k - 1 );
340+ grad (off) += cnt * projection (g, k - 1 );
336341 }
337342 }
338343 }
@@ -344,7 +349,6 @@ arma::vec gradient(
344349 off = main_index (base + 1 , 0 );
345350 grad (off) += blume_capel_stats (1 , v);
346351
347- // diffs
348352 if (inclusion_indicator (v, v) != 0 ) {
349353 for (int k = 1 ; k < num_groups; ++k) {
350354 off = main_index (base, k);
@@ -382,9 +386,9 @@ arma::vec gradient(
382386 const int r0 = group_indices (g, 0 );
383387 const int r1 = group_indices (g, 1 );
384388
385- arma::mat main_group (num_variables, max_num_categories, arma::fill::zeros);
386- arma::mat pairwise_group (num_variables, num_variables, arma::fill::zeros);
387389 const arma::vec proj_g = projection.row (g).t (); // length = num_groups-1
390+ main_group.zeros ();
391+ pairwise_group.zeros ();
388392
389393 // build group-specific params
390394 for (int v = 0 ; v < num_variables; ++v) {
@@ -393,8 +397,7 @@ arma::vec gradient(
393397 );
394398 main_group (v, arma::span (0 , me.n_elem - 1 )) = me.t ();
395399
396- for (int u = v; u < num_variables; ++u) { // Combines with loop over v
397- if (u == v) continue ;
400+ for (int u = v + 1 ; u < num_variables; ++u) {
398401 double w = compute_group_pairwise_effects (
399402 v, u, num_groups, pairwise_effects, pairwise_effect_indices,
400403 inclusion_indicator, proj_g
@@ -415,15 +418,14 @@ arma::vec gradient(
415418
416419 arma::vec rest_score = residual_matrix.col (v);
417420 arma::vec bound = K * rest_score;
418- bound = arma:: clamp (bound, 0.0 , arma::datum::inf);
421+ bound. clamp (0.0 , arma::datum::inf);
419422
420- arma::mat exponents (num_group_obs, K + 1 , arma::fill::zeros );
423+ arma::mat exponents (num_group_obs, K + 1 , arma::fill::none );
421424
422425 if (is_ordinal_variable (v)) {
423- exponents.col (0 ) -= bound;
424- arma::vec main_param = main_group.row (v).cols (0 , K - 1 ).t ();
425- for (int j = 0 ; j < K; j++) {
426- exponents.col (j+1 ) = main_param (j) + (j + 1 ) * rest_score - bound;
426+ exponents.col (0 ) = -bound;
427+ for (int j = 0 ; j < K; ++j) {
428+ exponents.col (j + 1 ) = main_group (v, j) + (j + 1 ) * rest_score - bound;
427429 }
428430 } else {
429431 const double lin_effect = main_group (v, 0 );
@@ -442,58 +444,72 @@ arma::vec gradient(
442444
443445 // ---- MAIN expected ----
444446 const int base = main_effect_indices (v, 0 );
445-
446447 if (is_ordinal_variable (v)) {
447448 for (int s = 1 ; s <= K; ++s) {
448449 const int j = s - 1 ;
450+ double sum_col_s = arma::accu (probs.col (s));
451+
452+
449453 off = main_index (base + j, 0 );
450- grad (off) -= arma::accu (probs.col (s));
454+ grad (off) -= sum_col_s;
455+
451456
452457 if (inclusion_indicator (v, v) == 0 ) continue ;
453458 for (int k = 1 ; k < num_groups; ++k) {
454459 off = main_index (base + j, k);
455- grad (off) -= projection (g, k - 1 ) * arma::accu (probs.col (s));
460+ grad (off) -= projection (g, k - 1 ) * sum_col_s;
461+
456462 }
457463 }
458464 } else {
465+
459466 arma::vec lin_score = arma::regspace<arma::vec>(0 , K); // length K+1
460467 arma::vec quad_score = arma::square (lin_score - ref);
461468
462- off = main_index (base, 0 );
463- grad (off) - = arma::accu (probs * lin_score );
469+ double sum_lin = arma::accu (probs * lin_score );
470+ double sum_quad = arma::accu (probs * quad_score );
464471
472+ off = main_index (base, 0 );
473+ grad (off) -= sum_lin;
465474 off = main_index (base + 1 , 0 );
466- grad (off) -= arma::accu (probs * quad_score) ;
475+ grad (off) -= sum_quad ;
467476
468477 if (inclusion_indicator (v, v) == 0 ) continue ;
469478 for (int k = 1 ; k < num_groups; ++k) {
470479 off = main_index (base, k);
471- grad (off) -= projection (g, k - 1 ) * arma::accu (probs * lin_score);
480+ grad (off) -= projection (g, k - 1 ) * sum_lin;
481+
472482 off = main_index (base + 1 , k);
473- grad (off) -= projection (g, k - 1 ) * arma::accu (probs * quad_score) ;
483+ grad (off) -= projection (g, k - 1 ) * sum_quad ;
474484 }
475485 }
476486
477487 // ---- PAIRWISE expected ----
478488 for (int v2 = 0 ; v2 < num_variables; ++v2) {
479489 if (v == v2) continue ;
480490
481- const int row = (v < v2) ? pairwise_effect_indices (v, v2)
482- : pairwise_effect_indices (v2, v);
483-
484491 arma::vec expected_value (num_group_obs, arma::fill::zeros);
485-
486492 for (int s = 1 ; s <= K; ++s) {
487493 expected_value += s * probs.col (s) % obs.col (v2);
488494 }
495+ double sum_expectation = arma::accu (expected_value);
496+
497+ // this is mathematically equivalent but numerically different...
498+ // double sum_expectation = 0.0;
499+ // for (int s = 1; s <= K; ++s) {
500+ // sum_expectation += s * arma::dot(probs.col(s), obs.col(v2));
501+ // }
502+
503+ const int row = (v < v2) ? pairwise_effect_indices (v, v2)
504+ : pairwise_effect_indices (v2, v);
489505
490506 off = pair_index (row, 0 );
491- grad (off) -= arma::accu (expected_value) ;
507+ grad (off) -= sum_expectation ;
492508
493509 if (inclusion_indicator (v, v2) == 0 ) continue ;
494510 for (int k = 1 ; k < num_groups; ++k) {
495511 off = pair_index (row, k);
496- grad (off) -= projection (g, k - 1 ) * arma::accu (expected_value) ;
512+ grad (off) -= projection (g, k - 1 ) * sum_expectation ;
497513 }
498514 }
499515 }
@@ -706,12 +722,12 @@ double log_pseudoposterior_main_component(
706722 arma::vec denom (rest_score.n_elem , arma::fill::zeros);
707723 if (is_ordinal_variable (variable)) {
708724 // base term exp(-bound)
709- denom = arma::exp (-bound);
725+ denom = ARMA_MY_EXP (-bound);
710726 // main_effects from main_group
711727 for (int cat = 0 ; cat < num_cats; cat++) {
712728 const double th = main_group (variable, cat);
713729 const arma::vec exponent = th + (cat + 1 ) * rest_score - bound;
714- denom += arma::exp (exponent);
730+ denom += ARMA_MY_EXP (exponent);
715731 }
716732 } else {
717733 // linear/quadratic main effects from main_group
@@ -723,11 +739,11 @@ double log_pseudoposterior_main_component(
723739 const double quad = quad_effect * centered * centered;
724740 const double lin = lin_effect * cat;
725741 const arma::vec exponent = lin + quad + cat * rest_score - bound;
726- denom += arma::exp (exponent);
742+ denom += ARMA_MY_EXP (exponent);
727743 }
728744 }
729745 // - sum_i [ bound_i + log denom_i ]
730- log_pp -= arma::accu (bound + arma::log (denom));
746+ log_pp -= arma::accu (bound + ARMA_MY_LOG (denom));
731747 }
732748
733749 // ---- priors ----
@@ -890,12 +906,12 @@ double log_pseudoposterior_pair_component(
890906
891907 if (is_ordinal_variable (v)) {
892908 // base term exp(-bound)
893- denom = arma::exp (-bound);
909+ denom = ARMA_MY_EXP (-bound);
894910 // main_effects from main_group
895911 for (int c = 0 ; c < num_cats; ++c) {
896912 const double th = main_group (v, c);
897913 const arma::vec exponent = th + (c + 1 ) * rest_score - bound;
898- denom += arma::exp (exponent);
914+ denom += ARMA_MY_EXP (exponent);
899915 }
900916 } else {
901917 // linear/quadratic main effects from main_group
@@ -907,11 +923,11 @@ double log_pseudoposterior_pair_component(
907923 const double quad = quad_effect * centered * centered;
908924 const double lin = lin_effect * c;
909925 const arma::vec exponent = lin + quad + c * rest_score - bound;
910- denom += arma::exp (exponent);
926+ denom += ARMA_MY_EXP (exponent);
911927 }
912928 }
913929 // - sum_i [ bound_i + log denom_i ]
914- log_pp -= arma::accu (bound + arma::log (denom));
930+ log_pp -= arma::accu (bound + ARMA_MY_LOG (denom));
915931 }
916932 }
917933
@@ -1037,12 +1053,12 @@ double log_ratio_pseudolikelihood_constant_variable(
10371053 bound_current = num_cats * arma::clamp (rest_current, 0.0 , arma::datum::inf);
10381054 bound_proposed = num_cats * arma::clamp (rest_proposed, 0.0 , arma::datum::inf);
10391055
1040- denom_current = arma::exp (-bound_current);
1041- denom_proposed = arma::exp (-bound_proposed);
1056+ denom_current = ARMA_MY_EXP (-bound_current);
1057+ denom_proposed = ARMA_MY_EXP (-bound_proposed);
10421058
10431059 for (int c = 0 ; c < num_cats; ++c) {
1044- denom_current += arma::exp (main_current (c) + (c + 1 ) * rest_current - bound_current);
1045- denom_proposed += arma::exp (main_proposed (c) + (c + 1 ) * rest_proposed - bound_proposed);
1060+ denom_current += ARMA_MY_EXP (main_current (c) + (c + 1 ) * rest_current - bound_current);
1061+ denom_proposed += ARMA_MY_EXP (main_proposed (c) + (c + 1 ) * rest_proposed - bound_proposed);
10461062 }
10471063 } else {
10481064 // Blume-Capel: linear + quadratic
@@ -1063,14 +1079,14 @@ double log_ratio_pseudolikelihood_constant_variable(
10631079 bound_proposed = lbound + num_cats * arma::clamp (rest_proposed, 0.0 , arma::datum::inf);
10641080
10651081 for (int s = 0 ; s <= num_cats; ++s) {
1066- denom_current += arma::exp (const_current (s) + s * rest_current - bound_current);
1067- denom_proposed += arma::exp (const_proposed (s) + s * rest_proposed - bound_proposed);
1082+ denom_current += ARMA_MY_EXP (const_current (s) + s * rest_current - bound_current);
1083+ denom_proposed += ARMA_MY_EXP (const_proposed (s) + s * rest_proposed - bound_proposed);
10681084 }
10691085 }
10701086
10711087 // --- accumulate contribution ---
10721088 log_ratio += arma::accu ((bound_current - bound_proposed) +
1073- arma::log (denom_current) - arma::log (denom_proposed));
1089+ ARMA_MY_LOG (denom_current) - ARMA_MY_LOG (denom_proposed));
10741090 }
10751091
10761092 return log_ratio;
0 commit comments