@@ -554,6 +554,21 @@ double log_pseudoposterior (
554554 const double lin_effect = main_effects (variable, 0 );
555555 const double quad_effect = main_effects (variable, 1 );
556556
557+ // ----
558+ const int score_min = -ref;
559+ const int score_max = num_cats - ref;
560+ const int max_diff = (std::abs (score_min) >= std::abs (score_max)) ? score_min : score_max;
561+ bound = max_diff * residual_score; // numerical bound vector
562+
563+ double main_bound = lin_effect * score_min + quad_effect * score_min * score_min;
564+ for (int cat = 1 ; cat <= num_cats; cat++) {
565+ const int score = cat - ref;
566+ const double tmp = lin_effect * score + quad_effect * score * score;
567+ if (std::abs (tmp) > std::abs (main_bound)) main_bound = tmp;
568+ }
569+ bound += main_bound; // final bound adjustment
570+ // ----
571+
557572 for (int cat = 0 ; cat <= num_cats; cat++) {
558573 int score = cat - ref; // centered category
559574 double lin = lin_effect * score; // precompute linear term
@@ -683,9 +698,9 @@ arma::vec gradient_log_pseudoposterior(
683698 for (int variable = 0 ; variable < num_variables; variable++) {
684699 const int num_cats = num_categories (variable);
685700 arma::vec residual_score = residual_matrix.col (variable);
686- arma::vec bound = num_cats * residual_score;
687701
688702 if (is_ordinal_variable (variable)) {
703+ arma::vec bound = num_cats * residual_score;
689704 arma::vec main_param = main_effects.row (variable).cols (0 , num_cats - 1 ).t ();
690705 arma::mat probs = compute_probs (
691706 main_param, residual_score, bound, num_cats
@@ -708,22 +723,69 @@ arma::vec gradient_log_pseudoposterior(
708723 }
709724 offset += num_cats;
710725 } else {
726+
711727 const int ref = baseline_category (variable);
712728 const double lin_eff = main_effects (variable, 0 );
713729 const double quad_eff = main_effects (variable, 1 );
714730
715- arma::mat exponents (num_persons, num_cats + 1 );
716- for (int cat = 0 ; cat <= num_cats; cat++) {
717- int score = cat - ref;
718- double lin = lin_eff * score;
719- double quad = quad_eff * score * score;
720- exponents.col (cat) = lin + quad + score * residual_score - bound;
731+ // Compute bounds
732+ const int score_min = -ref;
733+ const int score_max = num_cats - ref;
734+ const int max_diff = (std::abs (score_min) >= std::abs (score_max)) ? score_min : score_max;
735+ arma::vec bound = max_diff * residual_score;
736+
737+ double main_bound = lin_eff * score_min + quad_eff * score_min * score_min;
738+ for (int cat = 1 ; cat <= num_cats; cat++) {
739+ const int score = cat - ref;
740+ const double tmp = lin_eff * score + quad_eff * score * score;
741+ if (std::abs (tmp) > std::abs (main_bound)) main_bound = tmp;
721742 }
743+ bound += main_bound;
744+
745+ // arma::mat exponents(num_persons, num_cats + 1);
746+ // for (int cat = 0; cat <= num_cats; cat++) {
747+ // int score = cat - ref;
748+ // double lin = lin_eff * score;
749+ // double quad = quad_eff * score * score;
750+ // exponents.col(cat) = lin + quad + score * residual_score - bound;
751+ // }
752+ // arma::mat probs = ARMA_MY_EXP(exponents);
753+ // arma::vec denom = arma::sum(probs, 1);
754+ // probs.each_col() /= denom;
755+
756+ // arma::ivec lin_score = arma::regspace<arma::ivec>(0 - ref, num_cats - ref);
757+ // arma::ivec quad_score = arma::square(lin_score);
758+
759+ // Compute exponents
760+ arma::vec scores = arma::regspace<arma::vec>(0 - ref, num_cats - ref);
761+ arma::rowvec offsets = lin_eff * scores.t () + quad_eff * arma::square (scores.t ());
762+ arma::mat exponents = residual_score * scores.t ();
763+ exponents.each_row () += offsets;
764+ exponents.each_col () -= bound;
765+ arma::vec row_max = arma::max (exponents, /* dim=*/ 1 );
766+ exponents.each_col () -= row_max;
767+
768+ // Compute probabilities
722769 arma::mat probs = ARMA_MY_EXP (exponents);
723770 arma::vec denom = arma::sum (probs, 1 );
771+ // Guard against zeros/NaNs in denom (can happen if all entries underflow to 0)
772+ arma::uvec bad = arma::find_nonfinite (denom);
773+ bad = arma::join_vert (bad, arma::find (denom <= 0 ));
774+ bad = arma::unique (bad);
775+ if (!bad.is_empty ()) {
776+ // Fallback: make the max-exponent entry 1 and others 0 for those rows
777+ // (softmax limit)
778+ arma::uvec idx_max = arma::index_max (exponents.rows (bad), 1 );
779+ probs.rows (bad).zeros ();
780+ for (arma::uword i = 0 ; i < bad.n_elem ; ++i) {
781+ probs (bad (i), idx_max (i)) = 1.0 ;
782+ }
783+ // fix denom to 1 for those rows so the division below is safe
784+ denom.elem (bad).ones ();
785+ }
724786 probs.each_col () /= denom;
725787
726- arma::ivec lin_score = arma::regspace <arma::ivec>( 0 - ref, num_cats - ref );
788+ arma::ivec lin_score = arma::conv_to <arma::ivec>:: from (scores );
727789 arma::ivec quad_score = arma::square (lin_score);
728790
729791 // main effects
@@ -772,7 +834,6 @@ arma::vec gradient_log_pseudoposterior(
772834 gradient (location) -= 2.0 * effect / (effect * effect + pairwise_scale * pairwise_scale);
773835 }
774836 }
775-
776837 return gradient;
777838}
778839
0 commit comments