@@ -87,80 +87,6 @@ inline arma::vec compute_denom_ordinal(const arma::vec& residual,
8787
8888
8989
90- // inline arma::vec compute_denom_blumecapel(const arma::vec& residual,
91- // const arma::vec& bound,
92- // int num_cats,
93- // double theta_lin,
94- // double theta_quad,
95- // int ref)
96- // {
97- // constexpr double EXP_BOUND = 709.0;
98- // const arma::uword N = bound.n_elem;
99- //
100- // const int C = num_cats + 1; // number of terms
101- // std::vector<double> A(C);
102- // std::vector<int> S(C);
103- // for (int cat = 0; cat <= num_cats; ++cat) {
104- // const int s = cat - ref;
105- // S[cat] = s;
106- // A[cat] = ARMA_MY_EXP(theta_lin * s + theta_quad * double(s) * double(s));
107- // }
108- //
109- // arma::vec denom(N, arma::fill::none);
110- //
111- // // Fast block: factor via eB and power chain for exp(score * r)
112- // auto fast_block = [&](arma::uword i0, arma::uword i1) {
113- // arma::vec r = residual.rows(i0, i1);
114- // arma::vec b = bound.rows(i0, i1);
115- // arma::vec eB = ARMA_MY_EXP(-b);
116- // arma::vec eR = ARMA_MY_EXP(r);
117- //
118- // // Start power at s_min = -ref: pow = exp(s_min * r) = exp(r)^{s_min}
119- // arma::vec pow = ARMA_MY_EXP( double(-ref) * r );
120- //
121- // arma::vec d(pow.n_elem, arma::fill::zeros);
122- //
123- // for (int cat = 0; cat <= num_cats; ++cat) {
124- // // term = A_c * exp(score*r) * exp(-b)
125- // arma::vec t = A[cat] * pow % eB;
126- // d += t;
127- // pow %= eR; // score increments by +1 each step
128- // }
129- // denom.rows(i0, i1) = d;
130- // };
131- //
132- // // Safe block: stabilized exponent; compute per-category exponents directly
133- // auto safe_block = [&](arma::uword i0, arma::uword i1) {
134- // arma::vec r = residual.rows(i0, i1);
135- // arma::vec b = bound.rows(i0, i1);
136- // arma::vec d = ARMA_MY_EXP(-b); // see note above re "+ e^{-b}" term
137- // for (int cat = 0; cat <= num_cats; ++cat) {
138- // const int s = S[cat];
139- // arma::vec ex = theta_lin * s + theta_quad * double(s) * double(s)
140- // + double(s) * r - b;
141- // d += ARMA_MY_EXP(ex);
142- // }
143- // denom.rows(i0, i1) = d;
144- // };
145- //
146- // // Span-wise over contiguous runs
147- // const double* bp = bound.memptr();
148- // for (arma::uword i = 0; i < N; ) {
149- // const bool fast = !(bp[i] < -EXP_BOUND || bp[i] > EXP_BOUND);
150- // arma::uword j = i + 1;
151- // while (j < N) {
152- // const bool fj = !(bp[j] < -EXP_BOUND || bp[j] > EXP_BOUND);
153- // if (fj != fast) break;
154- // ++j;
155- // }
156- // if (fast) fast_block(i, j - 1); else safe_block(i, j - 1);
157- // i = j;
158- // }
159- // return denom;
160- // }
161-
162-
163-
16490/* *
16591 * Compute category probabilities in a numerically stable manner.
16692 *
@@ -175,10 +101,10 @@ inline arma::vec compute_denom_ordinal(const arma::vec& residual,
175101 * Returns:
176102 * probs: num_persons × num_cats matrix of probabilities (row-normalized)
177103 */
178- inline arma::mat compute_probs (const arma::vec& main_param,
179- const arma::vec& residual_score,
180- const arma::vec& bound,
181- int num_cats)
104+ inline arma::mat compute_probs_ordinal (const arma::vec& main_param,
105+ const arma::vec& residual_score,
106+ const arma::vec& bound,
107+ int num_cats)
182108{
183109 constexpr double EXP_BOUND = 709.0 ;
184110 const arma::uword N = bound.n_elem ;
@@ -554,15 +480,17 @@ double log_pseudoposterior (
554480 const double lin_effect = main_effects (variable, 0 );
555481 const double quad_effect = main_effects (variable, 1 );
556482
483+ arma::mat exponents (num_persons, num_cats + 1 );
557484 for (int cat = 0 ; cat <= num_cats; cat++) {
558485 int score = cat - ref; // centered category
559486 double lin = lin_effect * score; // precompute linear term
560487 double quad = quad_effect * score * score; // precompute quadratic term
561- arma::vec exponent = lin + quad + score * residual_score - bound;
562- denom += ARMA_MY_EXP (exponent); // accumulate over categories
488+ exponents.col (cat) = lin + quad + cat * residual_score;
563489 }
490+ bound = arma::max (exponents, /* dim=*/ 1 );
491+ exponents.each_col () -= bound;
492+ denom = arma::sum (ARMA_MY_EXP (exponents), 1 );
564493 }
565-
566494 log_pseudoposterior -= arma::accu (bound + ARMA_MY_LOG (denom)); // total contribution
567495 }
568496
@@ -687,7 +615,7 @@ arma::vec gradient_log_pseudoposterior(
687615
688616 if (is_ordinal_variable (variable)) {
689617 arma::vec main_param = main_effects.row (variable).cols (0 , num_cats - 1 ).t ();
690- arma::mat probs = compute_probs (
618+ arma::mat probs = compute_probs_ordinal (
691619 main_param, residual_score, bound, num_cats
692620 );
693621
@@ -711,19 +639,39 @@ arma::vec gradient_log_pseudoposterior(
711639 const int ref = baseline_category (variable);
712640 const double lin_eff = main_effects (variable, 0 );
713641 const double quad_eff = main_effects (variable, 1 );
642+ arma::vec scores = arma::regspace<arma::vec>(0 - ref, num_cats - ref);
714643
715644 arma::mat exponents (num_persons, num_cats + 1 );
716645 for (int cat = 0 ; cat <= num_cats; cat++) {
717- int score = cat - ref;
718- double lin = lin_eff * score;
719- double quad = quad_eff * score * score;
720- exponents.col (cat) = lin + quad + score * residual_score - bound ;
646+ int score = cat - ref; // centered category
647+ double lin = lin_eff * score; // precompute linear term
648+ double quad = quad_eff * score * score; // precompute quadratic term
649+ exponents.col (cat) = lin + quad + cat * residual_score;
721650 }
651+ bound = arma::max (exponents, /* dim=*/ 1 );
652+ exponents.each_col () -= bound;
653+
654+ // Compute probabilities
722655 arma::mat probs = ARMA_MY_EXP (exponents);
723656 arma::vec denom = arma::sum (probs, 1 );
657+ // Guard against zeros/NaNs in denom (can happen if all entries underflow to 0)
658+ arma::uvec bad = arma::find_nonfinite (denom);
659+ bad = arma::join_vert (bad, arma::find (denom <= 0 ));
660+ bad = arma::unique (bad);
661+ if (!bad.is_empty ()) {
662+ // Fallback: make the max-exponent entry 1 and others 0 for those rows
663+ // (softmax limit)
664+ arma::uvec idx_max = arma::index_max (exponents.rows (bad), 1 );
665+ probs.rows (bad).zeros ();
666+ for (arma::uword i = 0 ; i < bad.n_elem ; ++i) {
667+ probs (bad (i), idx_max (i)) = 1.0 ;
668+ }
669+ // fix denom to 1 for those rows so the division below is safe
670+ denom.elem (bad).ones ();
671+ }
724672 probs.each_col () /= denom;
725673
726- arma::ivec lin_score = arma::regspace <arma::ivec>( 0 - ref, num_cats - ref );
674+ arma::ivec lin_score = arma::conv_to <arma::ivec>:: from (scores );
727675 arma::ivec quad_score = arma::square (lin_score);
728676
729677 // main effects
0 commit comments