@@ -281,8 +281,8 @@ arma::vec gradient(
281281 const double main_beta,
282282 const double interaction_scale,
283283 const double difference_scale,
284- const arma::imat main_index,
285- const arma::imat pair_index
284+ const arma::imat& main_index,
285+ const arma::imat& pair_index
286286) {
287287 const int num_variables = observations.n_cols ;
288288 const int max_num_categories = num_categories.max ();
@@ -300,22 +300,30 @@ arma::vec gradient(
300300 total_len += static_cast <long long >(r1 - r0 + 1 ) * (num_groups - 1 );
301301 }
302302 }
303- for (int v1 = 0 ; v1 < num_variables - 1 ; ++v1 ) {
304- for (int v2 = v1 + 1 ; v2 < num_variables; ++v2 ) {
305- if (inclusion_indicator (v1, v2) == 1 ) total_len += (num_groups - 1 );
303+ for (int v2 = 0 ; v2 < num_variables - 1 ; ++v2 ) {
304+ for (int v1 = v2 + 1 ; v1 < num_variables; ++v1 ) {
305+ total_len += (inclusion_indicator (v1, v2) == 1 ) * (num_groups - 1 );
306306 }
307307 }
308308
309309 arma::vec grad (total_len, arma::fill::zeros);
310310 int off;
311311
312+ // -------------------------------------------------
313+ // Allocate temporaries ONCE (reused inside loops)
314+ // -------------------------------------------------
315+ arma::mat main_group (num_variables, max_num_categories, arma::fill::zeros);
316+ arma::mat pairwise_group (num_variables, num_variables, arma::fill::zeros);
317+ arma::mat exponents; // resized per variable
318+ arma::vec lin_score, quad_score;
319+
312320 // -------------------------------
313321 // Observed sufficient statistics
314322 // -------------------------------
315323 for (int g = 0 ; g < num_groups; ++g) {
316- // list access
317- arma::imat counts_per_category = counts_per_category_group [g];
318- arma::imat blume_capel_stats = blume_capel_stats_group [g];
324+ const arma::imat& counts_per_category = counts_per_category_group[g];
325+ const arma::imat& blume_capel_stats = blume_capel_stats_group [g];
326+ const arma::mat& pairwise_stats = pairwise_stats_group [g];
319327
320328 // Main effects
321329 for (int v = 0 ; v < num_variables; ++v) {
@@ -324,15 +332,14 @@ arma::vec gradient(
324332
325333 if (is_ordinal_variable (v)) {
326334 for (int c = 0 ; c < num_cats; ++c) {
327- // overall
328335 off = main_index (base + c, 0 );
329336 grad (off) += counts_per_category (c, v);
330337
331- // diffs
332338 if (inclusion_indicator (v, v) != 0 ) {
339+ const int cnt = counts_per_category (c, v);
333340 for (int k = 1 ; k < num_groups; ++k) {
334341 off = main_index (base + c, k);
335- grad (off) += counts_per_category (c, v) * projection (g, k - 1 );
342+ grad (off) += cnt * projection (g, k - 1 );
336343 }
337344 }
338345 }
@@ -344,7 +351,6 @@ arma::vec gradient(
344351 off = main_index (base + 1 , 0 );
345352 grad (off) += blume_capel_stats (1 , v);
346353
347- // diffs
348354 if (inclusion_indicator (v, v) != 0 ) {
349355 for (int k = 1 ; k < num_groups; ++k) {
350356 off = main_index (base, k);
@@ -358,7 +364,6 @@ arma::vec gradient(
358364 }
359365
360366 // Pairwise (observed)
361- arma::mat pairwise_stats = pairwise_stats_group[g];
362367 for (int v1 = 0 ; v1 < num_variables - 1 ; ++v1) {
363368 for (int v2 = v1 + 1 ; v2 < num_variables; ++v2) {
364369 const int row = pairwise_effect_indices (v1, v2);
@@ -382,9 +387,9 @@ arma::vec gradient(
382387 const int r0 = group_indices (g, 0 );
383388 const int r1 = group_indices (g, 1 );
384389
385- arma::mat main_group (num_variables, max_num_categories, arma::fill:: zeros);
386- arma::mat pairwise_group (num_variables, num_variables, arma::fill:: zeros);
387- const arma::vec proj_g = projection.row (g).t (); // length = num_groups-1
390+ main_group. zeros ( );
391+ pairwise_group. zeros ( );
392+ arma::vec proj_g = projection.row (g).t (); // length = num_groups-1
388393
389394 // build group-specific params
390395 for (int v = 0 ; v < num_variables; ++v) {
@@ -393,8 +398,7 @@ arma::vec gradient(
393398 );
394399 main_group (v, arma::span (0 , me.n_elem - 1 )) = me.t ();
395400
396- for (int u = v; u < num_variables; ++u) { // Combines with loop over v
397- if (u == v) continue ;
401+ for (int u = v + 1 ; u < num_variables; ++u) {
398402 double w = compute_group_pairwise_effects (
399403 v, u, num_groups, pairwise_effects, pairwise_effect_indices,
400404 inclusion_indicator, proj_g
@@ -415,15 +419,15 @@ arma::vec gradient(
415419
416420 arma::vec rest_score = residual_matrix.col (v);
417421 arma::vec bound = K * rest_score;
418- bound = arma:: clamp (bound, 0.0 , arma::datum::inf);
422+ bound. clamp (0.0 , arma::datum::inf);
419423
420- arma::mat exponents (num_group_obs, K + 1 , arma::fill::zeros);
424+ exponents.set_size (num_group_obs, K + 1 );
425+ exponents.zeros ();
421426
422427 if (is_ordinal_variable (v)) {
423428 exponents.col (0 ) -= bound;
424- arma::vec main_param = main_group.row (v).cols (0 , K - 1 ).t ();
425- for (int j = 0 ; j < K; j++) {
426- exponents.col (j+1 ) = main_param (j) + (j + 1 ) * rest_score - bound;
429+ for (int j = 0 ; j < K; ++j) {
430+ exponents.col (j + 1 ) = main_group (v, j) + (j + 1 ) * rest_score - bound;
427431 }
428432 } else {
429433 const double lin_effect = main_group (v, 0 );
@@ -436,41 +440,56 @@ arma::vec gradient(
436440 }
437441 }
438442
439- arma::mat probs = arma::exp (exponents);
440- arma::vec denom = arma::sum (probs, 1 ); // base term
441- probs.each_col () /= denom;
443+ // exponentiate + normalize row-wise (in-place)
444+ for (arma::uword i = 0 ; i < exponents.n_rows ; ++i) {
445+ double denom = 0.0 ;
446+ for (arma::uword j = 0 ; j < exponents.n_cols ; ++j) {
447+ double val = MY_EXP (exponents (i, j));
448+ exponents (i, j) = val;
449+ denom += val;
450+ }
451+ double inv = 1.0 / denom;
452+ for (arma::uword j = 0 ; j < exponents.n_cols ; ++j)
453+ exponents (i, j) *= inv;
454+ }
442455
443456 // ---- MAIN expected ----
444457 const int base = main_effect_indices (v, 0 );
445-
446458 if (is_ordinal_variable (v)) {
447459 for (int s = 1 ; s <= K; ++s) {
448460 const int j = s - 1 ;
461+ double sum_col_s = arma::accu (exponents.col (s));
462+
449463 off = main_index (base + j, 0 );
450- grad (off) -= arma::accu (probs. col (s)) ;
464+ grad (off) -= sum_col_s ;
451465
452466 if (inclusion_indicator (v, v) == 0 ) continue ;
453467 for (int k = 1 ; k < num_groups; ++k) {
454468 off = main_index (base + j, k);
455- grad (off) -= projection (g, k - 1 ) * arma::accu (probs. col (s)) ;
469+ grad (off) -= projection (g, k - 1 ) * sum_col_s ;
456470 }
457471 }
458472 } else {
459- arma::vec lin_score = arma::regspace<arma::vec>(0 , K); // length K+1
460- arma::vec quad_score = arma::square (lin_score - ref);
473+ if (lin_score.n_elem != K + 1 ) {
474+ lin_score = arma::regspace<arma::vec>(0 , K);
475+ quad_score = arma::square (lin_score - ref);
476+ }
461477
462- off = main_index (base, 0 );
463- grad (off) - = arma::accu (probs * lin_score );
478+ double sum_lin = arma::accu (exponents * lin_score );
479+ double sum_quad = arma::accu (exponents * quad_score );
464480
481+ off = main_index (base, 0 );
482+ grad (off) -= sum_lin;
465483 off = main_index (base + 1 , 0 );
466- grad (off) -= arma::accu (probs * quad_score) ;
484+ grad (off) -= sum_quad ;
467485
468486 if (inclusion_indicator (v, v) == 0 ) continue ;
469487 for (int k = 1 ; k < num_groups; ++k) {
470488 off = main_index (base, k);
471- grad (off) -= projection (g, k - 1 ) * arma::accu (probs * lin_score);
489+ grad (off) -= projection (g, k - 1 ) * sum_lin;
490+
472491 off = main_index (base + 1 , k);
473- grad (off) -= projection (g, k - 1 ) * arma::accu (probs * quad_score) ;
492+ grad (off) -= projection (g, k - 1 ) * sum_quad ;
474493 }
475494 }
476495
@@ -479,21 +498,20 @@ arma::vec gradient(
479498 if (v == v2) continue ;
480499
481500 const int row = (v < v2) ? pairwise_effect_indices (v, v2)
482- : pairwise_effect_indices (v2, v);
501+ : pairwise_effect_indices (v2, v);
483502
484- arma::vec expected_value (num_group_obs, arma::fill::zeros);
485-
486- for (int s = 1 ; s <= K; ++s) {
487- expected_value += s * probs.col (s) % obs.col (v2);
488- }
503+ double total = 0.0 ;
504+ for (int s = 1 ; s <= K; ++s)
505+ total += s * arma::dot (exponents.col (s), obs.col (v2));
489506
490507 off = pair_index (row, 0 );
491- grad (off) -= arma::accu (expected_value) ;
508+ grad (off) -= total ;
492509
493- if (inclusion_indicator (v, v2) == 0 ) continue ;
494- for (int k = 1 ; k < num_groups; ++k) {
495- off = pair_index (row, k);
496- grad (off) -= projection (g, k - 1 ) * arma::accu (expected_value);
510+ if (inclusion_indicator (v, v2)) {
511+ for (int k = 1 ; k < num_groups; ++k) {
512+ off = pair_index (row, k);
513+ grad (off) -= projection (g, k - 1 ) * total;
514+ }
497515 }
498516 }
499517 }
0 commit comments