@@ -18,13 +18,13 @@ use crate::{
1818
1919fn mat_all_finite ( mat : & MatRef < f64 > ) -> bool {
2020 let mut ok = true ;
21- faer:: zip!( mat) . for_each ( |faer:: unzip!( val) | ok = ok & val. is_finite ( ) ) ;
21+ faer:: zip!( mat) . for_each ( |faer:: unzip!( val) | ok &= val. is_finite ( ) ) ;
2222 ok
2323}
2424
2525fn col_all_finite ( mat : & ColRef < f64 > ) -> bool {
2626 let mut ok = true ;
27- faer:: zip!( mat) . for_each ( |faer:: unzip!( val) | ok = ok & val. is_finite ( ) ) ;
27+ faer:: zip!( mat) . for_each ( |faer:: unzip!( val) | ok &= val. is_finite ( ) ) ;
2828 ok
2929}
3030
@@ -414,10 +414,22 @@ impl LowRankMassMatrixStrategy {
414414 grads. col_as_slice_mut ( i) . copy_from_slice ( & grad[ ..] ) ;
415415 }
416416
417+ let Some ( ( stds, vals, vecs) ) = self . compute_update ( draws, grads) else {
418+ return ;
419+ } ;
420+
421+ matrix. update ( math, stds, vals, vecs) ;
422+ }
423+
424+ fn compute_update (
425+ & self ,
426+ mut draws : Mat < f64 > ,
427+ mut grads : Mat < f64 > ,
428+ ) -> Option < ( Col < f64 > , Col < f64 > , Mat < f64 > ) > {
417429 let stds = rescale_points ( & mut draws, & mut grads) ;
418430
419- let svd_draws = draws. thin_svd ( ) . expect ( "Failed to compute SVD" ) ;
420- let svd_grads = grads. thin_svd ( ) . expect ( "Failed to compute SVD" ) ;
431+ let svd_draws = draws. thin_svd ( ) . ok ( ) ? ;
432+ let svd_grads = grads. thin_svd ( ) . ok ( ) ? ;
421433
422434 let subspace = faer:: concat![ [ svd_draws. U ( ) , svd_grads. U ( ) ] ] ;
423435
@@ -428,7 +440,7 @@ impl LowRankMassMatrixStrategy {
428440 let draws_proj = subspace_basis. transpose ( ) * ( & draws) ;
429441 let grads_proj = subspace_basis. transpose ( ) * ( & grads) ;
430442
431- let ( vals, vecs) = estimate_mass_matrix ( draws_proj, grads_proj, self . settings . gamma ) ;
443+ let ( vals, vecs) = estimate_mass_matrix ( draws_proj, grads_proj, self . settings . gamma ) ? ;
432444
433445 let filtered = vals
434446 . iter ( )
@@ -448,8 +460,7 @@ impl LowRankMassMatrixStrategy {
448460 . for_each ( |( mut col, vals) | col. copy_from ( vals) ) ;
449461
450462 let vecs = subspace_basis * vecs;
451-
452- matrix. update ( math, stds, vals, vecs) ;
463+ Some ( ( stds, vals, vecs) )
453464 }
454465}
455466
@@ -490,7 +501,11 @@ fn rescale_points(draws: &mut Mat<f64>, grads: &mut Mat<f64>) -> Col<f64> {
490501 stds
491502}
492503
493- fn estimate_mass_matrix ( draws : Mat < f64 > , grads : Mat < f64 > , gamma : f64 ) -> ( Col < f64 > , Mat < f64 > ) {
504+ fn estimate_mass_matrix (
505+ draws : Mat < f64 > ,
506+ grads : Mat < f64 > ,
507+ gamma : f64 ,
508+ ) -> Option < ( Col < f64 > , Mat < f64 > ) > {
494509 let mut cov_draws = ( & draws) * draws. transpose ( ) ;
495510 let mut cov_grads = ( & grads) * grads. transpose ( ) ;
496511
@@ -509,22 +524,18 @@ fn estimate_mass_matrix(draws: Mat<f64>, grads: Mat<f64>, gamma: f64) -> (Col<f6
509524 . iter_mut ( )
510525 . for_each ( |x| * x += 1f64 ) ;
511526
512- let mean = spd_mean ( cov_draws, cov_grads) ;
527+ let mean = spd_mean ( cov_draws, cov_grads) ? ;
513528
514- let mean_eig = mean
515- . self_adjoint_eigen ( faer:: Side :: Lower )
516- . expect ( "Failed to compute eigenvalue decomposition" ) ;
529+ let mean_eig = mean. self_adjoint_eigen ( faer:: Side :: Lower ) . ok ( ) ?;
517530
518- (
531+ Some ( (
519532 mean_eig. S ( ) . column_vector ( ) . to_owned ( ) ,
520533 mean_eig. U ( ) . to_owned ( ) ,
521- )
534+ ) )
522535}
523536
524- fn spd_mean ( cov_draws : Mat < f64 > , cov_grads : Mat < f64 > ) -> Mat < f64 > {
525- let eigs_grads = cov_grads
526- . self_adjoint_eigen ( faer:: Side :: Lower )
527- . expect ( "Failed to compute eigenvalue decomposition" ) ;
537+ fn spd_mean ( cov_draws : Mat < f64 > , cov_grads : Mat < f64 > ) -> Option < Mat < f64 > > {
538+ let eigs_grads = cov_grads. self_adjoint_eigen ( faer:: Side :: Lower ) . ok ( ) ?;
528539
529540 let u = eigs_grads. U ( ) ;
530541 let eigs = eigs_grads. S ( ) . column_vector ( ) . to_owned ( ) ;
@@ -534,9 +545,7 @@ fn spd_mean(cov_draws: Mat<f64>, cov_grads: Mat<f64>) -> Mat<f64> {
534545 let cov_grads_sqrt = u * eigs_sqrt. into_diagonal ( ) * u. transpose ( ) ;
535546 let m = ( & cov_grads_sqrt) * cov_draws * cov_grads_sqrt;
536547
537- let m_eig = m
538- . self_adjoint_eigen ( faer:: Side :: Lower )
539- . expect ( "Failed to compute eigenvalue decomposition" ) ;
548+ let m_eig = m. self_adjoint_eigen ( faer:: Side :: Lower ) . ok ( ) ?;
540549
541550 let m_u = m_eig. U ( ) ;
542551 let mut m_s = m_eig. S ( ) . column_vector ( ) . to_owned ( ) ;
@@ -550,7 +559,7 @@ fn spd_mean(cov_draws: Mat<f64>, cov_grads: Mat<f64>) -> Mat<f64> {
550559 . for_each ( |val| * val = val. sqrt ( ) . recip ( ) ) ;
551560 let grads_inv_sqrt = u * eigs_grads_inv. into_diagonal ( ) * u. transpose ( ) ;
552561
553- ( & grads_inv_sqrt) * m_sqrt * grads_inv_sqrt
562+ Some ( ( & grads_inv_sqrt) * m_sqrt * grads_inv_sqrt)
554563}
555564
556565impl < M : Math > SamplerStats < M > for LowRankMassMatrixStrategy {
@@ -656,7 +665,7 @@ mod test {
656665 x. diagonal_mut ( ) . column_vector_mut ( ) . add_assign ( x_diag) ;
657666 y. diagonal_mut ( ) . column_vector_mut ( ) . add_assign ( y_diag) ;
658667
659- let out = spd_mean ( x, y) ;
668+ let out = spd_mean ( x, y) . expect ( "Failed to compute spd mean" ) ;
660669 let expected_diag = faer:: col![ 1. , 2. , 4. ] ;
661670 let mut expected = faer:: Mat :: zeros ( 3 , 3 ) ;
662671 expected
@@ -684,7 +693,8 @@ mod test {
684693 //let grads: Mat<f64> = Mat::from_fn(20, 3, |_, _| rng.sample(distr));
685694 let grads = -( & draws) ;
686695
687- let ( vals, vecs) = estimate_mass_matrix ( draws, grads, 0.0001 ) ;
696+ let ( vals, vecs) =
697+ estimate_mass_matrix ( draws, grads, 0.0001 ) . expect ( "Failed to compute mass matrix" ) ;
688698 assert ! ( vals. iter( ) . cloned( ) . all( |x| x > 0. ) ) ;
689699 assert ! ( mat_all_finite( & vecs. as_ref( ) ) ) ;
690700
0 commit comments