Skip to content

Commit f30736e

Browse files
Fix numeric stability normalizing constants for BC variables for bgm().
1 parent 6134533 commit f30736e

File tree

4 files changed

+75
-14
lines changed

4 files changed

+75
-14
lines changed

R/bgm.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ bgm = function(
418418
} else if(update_method == "hamiltonian-mc") {
419419
target_accept = 0.65
420420
} else if(update_method == "nuts") {
421-
target_accept = 0.60
421+
target_accept = 0.80
422422
}
423423
}
424424

src/bgm_logp_and_grad.cpp

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,21 @@ double log_pseudoposterior (
554554
const double lin_effect = main_effects(variable, 0);
555555
const double quad_effect = main_effects(variable, 1);
556556

557+
// ----
558+
const int score_min = -ref;
559+
const int score_max = num_cats - ref;
560+
const int max_diff = (std::abs(score_min) >= std::abs(score_max)) ? score_min : score_max;
561+
bound = max_diff * residual_score; // numerical bound vector
562+
563+
double main_bound = lin_effect * score_min + quad_effect * score_min * score_min;
564+
for (int cat = 1; cat <= num_cats; cat++) {
565+
const int score = cat - ref;
566+
const double tmp = lin_effect * score + quad_effect * score * score;
567+
if (std::abs(tmp) > std::abs(main_bound)) main_bound = tmp;
568+
}
569+
bound += main_bound; // final bound adjustment
570+
// ----
571+
557572
for (int cat = 0; cat <= num_cats; cat++) {
558573
int score = cat - ref; // centered category
559574
double lin = lin_effect * score; // precompute linear term
@@ -683,9 +698,9 @@ arma::vec gradient_log_pseudoposterior(
683698
for (int variable = 0; variable < num_variables; variable++) {
684699
const int num_cats = num_categories(variable);
685700
arma::vec residual_score = residual_matrix.col(variable);
686-
arma::vec bound = num_cats * residual_score;
687701

688702
if (is_ordinal_variable(variable)) {
703+
arma::vec bound = num_cats * residual_score;
689704
arma::vec main_param = main_effects.row(variable).cols(0, num_cats - 1).t();
690705
arma::mat probs = compute_probs(
691706
main_param, residual_score, bound, num_cats
@@ -708,22 +723,69 @@ arma::vec gradient_log_pseudoposterior(
708723
}
709724
offset += num_cats;
710725
} else {
726+
711727
const int ref = baseline_category(variable);
712728
const double lin_eff = main_effects(variable, 0);
713729
const double quad_eff = main_effects(variable, 1);
714730

715-
arma::mat exponents(num_persons, num_cats + 1);
716-
for (int cat = 0; cat <= num_cats; cat++) {
717-
int score = cat - ref;
718-
double lin = lin_eff * score;
719-
double quad = quad_eff * score * score;
720-
exponents.col(cat) = lin + quad + score * residual_score - bound;
731+
// Compute bounds
732+
const int score_min = -ref;
733+
const int score_max = num_cats - ref;
734+
const int max_diff = (std::abs(score_min) >= std::abs(score_max)) ? score_min : score_max;
735+
arma::vec bound = max_diff * residual_score;
736+
737+
double main_bound = lin_eff * score_min + quad_eff * score_min * score_min;
738+
for (int cat = 1; cat <= num_cats; cat++) {
739+
const int score = cat - ref;
740+
const double tmp = lin_eff * score + quad_eff * score * score;
741+
if (std::abs(tmp) > std::abs(main_bound)) main_bound = tmp;
721742
}
743+
bound += main_bound;
744+
745+
// arma::mat exponents(num_persons, num_cats + 1);
746+
// for (int cat = 0; cat <= num_cats; cat++) {
747+
// int score = cat - ref;
748+
// double lin = lin_eff * score;
749+
// double quad = quad_eff * score * score;
750+
// exponents.col(cat) = lin + quad + score * residual_score - bound;
751+
// }
752+
// arma::mat probs = ARMA_MY_EXP(exponents);
753+
// arma::vec denom = arma::sum(probs, 1);
754+
// probs.each_col() /= denom;
755+
756+
// arma::ivec lin_score = arma::regspace<arma::ivec>(0 - ref, num_cats - ref);
757+
// arma::ivec quad_score = arma::square(lin_score);
758+
759+
// Compute exponents
760+
arma::vec scores = arma::regspace<arma::vec>(0 - ref, num_cats - ref);
761+
arma::rowvec offsets = lin_eff * scores.t() + quad_eff * arma::square(scores.t());
762+
arma::mat exponents = residual_score * scores.t();
763+
exponents.each_row() += offsets;
764+
exponents.each_col() -= bound;
765+
arma::vec row_max = arma::max(exponents, /*dim=*/1);
766+
exponents.each_col() -= row_max;
767+
768+
// Compute probabilities
722769
arma::mat probs = ARMA_MY_EXP(exponents);
723770
arma::vec denom = arma::sum(probs, 1);
771+
// Guard against zeros/NaNs in denom (can happen if all entries underflow to 0)
772+
arma::uvec bad = arma::find_nonfinite(denom);
773+
bad = arma::join_vert(bad, arma::find(denom <= 0));
774+
bad = arma::unique(bad);
775+
if (!bad.is_empty()) {
776+
// Fallback: make the max-exponent entry 1 and others 0 for those rows
777+
// (softmax limit)
778+
arma::uvec idx_max = arma::index_max(exponents.rows(bad), 1);
779+
probs.rows(bad).zeros();
780+
for (arma::uword i = 0; i < bad.n_elem; ++i) {
781+
probs(bad(i), idx_max(i)) = 1.0;
782+
}
783+
// fix denom to 1 for those rows so the division below is safe
784+
denom.elem(bad).ones();
785+
}
724786
probs.each_col() /= denom;
725787

726-
arma::ivec lin_score = arma::regspace<arma::ivec>(0 - ref, num_cats - ref);
788+
arma::ivec lin_score = arma::conv_to<arma::ivec>::from(scores);
727789
arma::ivec quad_score = arma::square(lin_score);
728790

729791
// main effects
@@ -772,7 +834,6 @@ arma::vec gradient_log_pseudoposterior(
772834
gradient(location) -= 2.0 * effect / (effect * effect + pairwise_scale * pairwise_scale);
773835
}
774836
}
775-
776837
return gradient;
777838
}
778839

src/mcmc_utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ double kinetic_energy(const arma::vec& r, const arma::vec& inv_mass_diag) {
3838
* - theta: Current parameter vector.
3939
* - log_post: Function to compute log posterior.
4040
* - grad: Function to compute gradient of log posterior.
41-
* - target_acceptance: Target acceptance rate (default 0.65).
41+
* - target_acceptance: Target acceptance rate (default 0.8).
4242
* - init_step: Initial step size to try (default 1.0).
43-
* - max_attempts: Max number of doubling/halving attempts (default 20).
43+
* - max_attempts: Max number of doubling/halving attempts (default 100).
4444
*
4545
* Returns:
4646
* - A step size epsilon resulting in log acceptance near -log(2).

src/mcmc_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ double heuristic_initial_step_size(
218218
const std::function<double(const arma::vec&)>& log_post,
219219
const std::function<arma::vec(const arma::vec&)>& grad,
220220
SafeRNG& rng,
221-
double target_acceptance = 0.625,
221+
double target_acceptance = 0.8,
222222
double init_step = 1.0,
223-
int max_attempts = 20
223+
int max_attempts = 100
224224
);

0 commit comments

Comments
 (0)