Skip to content

Commit 3a68b40

Browse files
committed
fix: eigen decomposition error for low rank mass matrix
1 parent 29ca08b commit 3a68b40

File tree

2 files changed

+35
-25
lines changed

2 files changed

+35
-25
lines changed

src/chain.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> StatTraceBuilder<M, NutsChain<M
416416

417417
if let Some(div_msg) = divergence_msg.as_mut() {
418418
if let Some(err) = div_info.and_then(|info| info.logp_function_error.as_ref()) {
419-
div_msg.append_value(format!("{}", err));
419+
div_msg.append_value(format!("{err}"));
420420
} else {
421421
div_msg.append_null();
422422
}

src/low_rank_mass_matrix.rs

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ use crate::{
1818

1919
fn mat_all_finite(mat: &MatRef<f64>) -> bool {
2020
let mut ok = true;
21-
faer::zip!(mat).for_each(|faer::unzip!(val)| ok = ok & val.is_finite());
21+
faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
2222
ok
2323
}
2424

2525
fn col_all_finite(mat: &ColRef<f64>) -> bool {
2626
let mut ok = true;
27-
faer::zip!(mat).for_each(|faer::unzip!(val)| ok = ok & val.is_finite());
27+
faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
2828
ok
2929
}
3030

@@ -414,10 +414,22 @@ impl LowRankMassMatrixStrategy {
414414
grads.col_as_slice_mut(i).copy_from_slice(&grad[..]);
415415
}
416416

417+
let Some((stds, vals, vecs)) = self.compute_update(draws, grads) else {
418+
return;
419+
};
420+
421+
matrix.update(math, stds, vals, vecs);
422+
}
423+
424+
fn compute_update(
425+
&self,
426+
mut draws: Mat<f64>,
427+
mut grads: Mat<f64>,
428+
) -> Option<(Col<f64>, Col<f64>, Mat<f64>)> {
417429
let stds = rescale_points(&mut draws, &mut grads);
418430

419-
let svd_draws = draws.thin_svd().expect("Failed to compute SVD");
420-
let svd_grads = grads.thin_svd().expect("Failed to compute SVD");
431+
let svd_draws = draws.thin_svd().ok()?;
432+
let svd_grads = grads.thin_svd().ok()?;
421433

422434
let subspace = faer::concat![[svd_draws.U(), svd_grads.U()]];
423435

@@ -428,7 +440,7 @@ impl LowRankMassMatrixStrategy {
428440
let draws_proj = subspace_basis.transpose() * (&draws);
429441
let grads_proj = subspace_basis.transpose() * (&grads);
430442

431-
let (vals, vecs) = estimate_mass_matrix(draws_proj, grads_proj, self.settings.gamma);
443+
let (vals, vecs) = estimate_mass_matrix(draws_proj, grads_proj, self.settings.gamma)?;
432444

433445
let filtered = vals
434446
.iter()
@@ -448,8 +460,7 @@ impl LowRankMassMatrixStrategy {
448460
.for_each(|(mut col, vals)| col.copy_from(vals));
449461

450462
let vecs = subspace_basis * vecs;
451-
452-
matrix.update(math, stds, vals, vecs);
463+
Some((stds, vals, vecs))
453464
}
454465
}
455466

@@ -490,7 +501,11 @@ fn rescale_points(draws: &mut Mat<f64>, grads: &mut Mat<f64>) -> Col<f64> {
490501
stds
491502
}
492503

493-
fn estimate_mass_matrix(draws: Mat<f64>, grads: Mat<f64>, gamma: f64) -> (Col<f64>, Mat<f64>) {
504+
fn estimate_mass_matrix(
505+
draws: Mat<f64>,
506+
grads: Mat<f64>,
507+
gamma: f64,
508+
) -> Option<(Col<f64>, Mat<f64>)> {
494509
let mut cov_draws = (&draws) * draws.transpose();
495510
let mut cov_grads = (&grads) * grads.transpose();
496511

@@ -509,22 +524,18 @@ fn estimate_mass_matrix(draws: Mat<f64>, grads: Mat<f64>, gamma: f64) -> (Col<f6
509524
.iter_mut()
510525
.for_each(|x| *x += 1f64);
511526

512-
let mean = spd_mean(cov_draws, cov_grads);
527+
let mean = spd_mean(cov_draws, cov_grads)?;
513528

514-
let mean_eig = mean
515-
.self_adjoint_eigen(faer::Side::Lower)
516-
.expect("Failed to compute eigenvalue decomposition");
529+
let mean_eig = mean.self_adjoint_eigen(faer::Side::Lower).ok()?;
517530

518-
(
531+
Some((
519532
mean_eig.S().column_vector().to_owned(),
520533
mean_eig.U().to_owned(),
521-
)
534+
))
522535
}
523536

524-
fn spd_mean(cov_draws: Mat<f64>, cov_grads: Mat<f64>) -> Mat<f64> {
525-
let eigs_grads = cov_grads
526-
.self_adjoint_eigen(faer::Side::Lower)
527-
.expect("Failed to compute eigenvalue decomposition");
537+
fn spd_mean(cov_draws: Mat<f64>, cov_grads: Mat<f64>) -> Option<Mat<f64>> {
538+
let eigs_grads = cov_grads.self_adjoint_eigen(faer::Side::Lower).ok()?;
528539

529540
let u = eigs_grads.U();
530541
let eigs = eigs_grads.S().column_vector().to_owned();
@@ -534,9 +545,7 @@ fn spd_mean(cov_draws: Mat<f64>, cov_grads: Mat<f64>) -> Mat<f64> {
534545
let cov_grads_sqrt = u * eigs_sqrt.into_diagonal() * u.transpose();
535546
let m = (&cov_grads_sqrt) * cov_draws * cov_grads_sqrt;
536547

537-
let m_eig = m
538-
.self_adjoint_eigen(faer::Side::Lower)
539-
.expect("Failed to compute eigenvalue decomposition");
548+
let m_eig = m.self_adjoint_eigen(faer::Side::Lower).ok()?;
540549

541550
let m_u = m_eig.U();
542551
let mut m_s = m_eig.S().column_vector().to_owned();
@@ -550,7 +559,7 @@ fn spd_mean(cov_draws: Mat<f64>, cov_grads: Mat<f64>) -> Mat<f64> {
550559
.for_each(|val| *val = val.sqrt().recip());
551560
let grads_inv_sqrt = u * eigs_grads_inv.into_diagonal() * u.transpose();
552561

553-
(&grads_inv_sqrt) * m_sqrt * grads_inv_sqrt
562+
Some((&grads_inv_sqrt) * m_sqrt * grads_inv_sqrt)
554563
}
555564

556565
impl<M: Math> SamplerStats<M> for LowRankMassMatrixStrategy {
@@ -656,7 +665,7 @@ mod test {
656665
x.diagonal_mut().column_vector_mut().add_assign(x_diag);
657666
y.diagonal_mut().column_vector_mut().add_assign(y_diag);
658667

659-
let out = spd_mean(x, y);
668+
let out = spd_mean(x, y).expect("Failed to compute spd mean");
660669
let expected_diag = faer::col![1., 2., 4.];
661670
let mut expected = faer::Mat::zeros(3, 3);
662671
expected
@@ -684,7 +693,8 @@ mod test {
684693
//let grads: Mat<f64> = Mat::from_fn(20, 3, |_, _| rng.sample(distr));
685694
let grads = -(&draws);
686695

687-
let (vals, vecs) = estimate_mass_matrix(draws, grads, 0.0001);
696+
let (vals, vecs) =
697+
estimate_mass_matrix(draws, grads, 0.0001).expect("Failed to compute mass matrix");
688698
assert!(vals.iter().cloned().all(|x| x > 0.));
689699
assert!(mat_all_finite(&vecs.as_ref()));
690700

0 commit comments

Comments
 (0)