@@ -18,13 +18,13 @@ use crate::{
18
18
19
19
fn mat_all_finite ( mat : & MatRef < f64 > ) -> bool {
20
20
let mut ok = true ;
21
- faer:: zip!( mat) . for_each ( |faer:: unzip!( val) | ok = ok & val. is_finite ( ) ) ;
21
+ faer:: zip!( mat) . for_each ( |faer:: unzip!( val) | ok &= val. is_finite ( ) ) ;
22
22
ok
23
23
}
24
24
25
25
fn col_all_finite ( mat : & ColRef < f64 > ) -> bool {
26
26
let mut ok = true ;
27
- faer:: zip!( mat) . for_each ( |faer:: unzip!( val) | ok = ok & val. is_finite ( ) ) ;
27
+ faer:: zip!( mat) . for_each ( |faer:: unzip!( val) | ok &= val. is_finite ( ) ) ;
28
28
ok
29
29
}
30
30
@@ -414,10 +414,22 @@ impl LowRankMassMatrixStrategy {
414
414
grads. col_as_slice_mut ( i) . copy_from_slice ( & grad[ ..] ) ;
415
415
}
416
416
417
+ let Some ( ( stds, vals, vecs) ) = self . compute_update ( draws, grads) else {
418
+ return ;
419
+ } ;
420
+
421
+ matrix. update ( math, stds, vals, vecs) ;
422
+ }
423
+
424
+ fn compute_update (
425
+ & self ,
426
+ mut draws : Mat < f64 > ,
427
+ mut grads : Mat < f64 > ,
428
+ ) -> Option < ( Col < f64 > , Col < f64 > , Mat < f64 > ) > {
417
429
let stds = rescale_points ( & mut draws, & mut grads) ;
418
430
419
- let svd_draws = draws. thin_svd ( ) . expect ( "Failed to compute SVD" ) ;
420
- let svd_grads = grads. thin_svd ( ) . expect ( "Failed to compute SVD" ) ;
431
+ let svd_draws = draws. thin_svd ( ) . ok ( ) ? ;
432
+ let svd_grads = grads. thin_svd ( ) . ok ( ) ? ;
421
433
422
434
let subspace = faer:: concat![ [ svd_draws. U ( ) , svd_grads. U ( ) ] ] ;
423
435
@@ -428,7 +440,7 @@ impl LowRankMassMatrixStrategy {
428
440
let draws_proj = subspace_basis. transpose ( ) * ( & draws) ;
429
441
let grads_proj = subspace_basis. transpose ( ) * ( & grads) ;
430
442
431
- let ( vals, vecs) = estimate_mass_matrix ( draws_proj, grads_proj, self . settings . gamma ) ;
443
+ let ( vals, vecs) = estimate_mass_matrix ( draws_proj, grads_proj, self . settings . gamma ) ? ;
432
444
433
445
let filtered = vals
434
446
. iter ( )
@@ -448,8 +460,7 @@ impl LowRankMassMatrixStrategy {
448
460
. for_each ( |( mut col, vals) | col. copy_from ( vals) ) ;
449
461
450
462
let vecs = subspace_basis * vecs;
451
-
452
- matrix. update ( math, stds, vals, vecs) ;
463
+ Some ( ( stds, vals, vecs) )
453
464
}
454
465
}
455
466
@@ -490,7 +501,11 @@ fn rescale_points(draws: &mut Mat<f64>, grads: &mut Mat<f64>) -> Col<f64> {
490
501
stds
491
502
}
492
503
493
- fn estimate_mass_matrix ( draws : Mat < f64 > , grads : Mat < f64 > , gamma : f64 ) -> ( Col < f64 > , Mat < f64 > ) {
504
+ fn estimate_mass_matrix (
505
+ draws : Mat < f64 > ,
506
+ grads : Mat < f64 > ,
507
+ gamma : f64 ,
508
+ ) -> Option < ( Col < f64 > , Mat < f64 > ) > {
494
509
let mut cov_draws = ( & draws) * draws. transpose ( ) ;
495
510
let mut cov_grads = ( & grads) * grads. transpose ( ) ;
496
511
@@ -509,22 +524,18 @@ fn estimate_mass_matrix(draws: Mat<f64>, grads: Mat<f64>, gamma: f64) -> (Col<f6
509
524
. iter_mut ( )
510
525
. for_each ( |x| * x += 1f64 ) ;
511
526
512
- let mean = spd_mean ( cov_draws, cov_grads) ;
527
+ let mean = spd_mean ( cov_draws, cov_grads) ? ;
513
528
514
- let mean_eig = mean
515
- . self_adjoint_eigen ( faer:: Side :: Lower )
516
- . expect ( "Failed to compute eigenvalue decomposition" ) ;
529
+ let mean_eig = mean. self_adjoint_eigen ( faer:: Side :: Lower ) . ok ( ) ?;
517
530
518
- (
531
+ Some ( (
519
532
mean_eig. S ( ) . column_vector ( ) . to_owned ( ) ,
520
533
mean_eig. U ( ) . to_owned ( ) ,
521
- )
534
+ ) )
522
535
}
523
536
524
- fn spd_mean ( cov_draws : Mat < f64 > , cov_grads : Mat < f64 > ) -> Mat < f64 > {
525
- let eigs_grads = cov_grads
526
- . self_adjoint_eigen ( faer:: Side :: Lower )
527
- . expect ( "Failed to compute eigenvalue decomposition" ) ;
537
+ fn spd_mean ( cov_draws : Mat < f64 > , cov_grads : Mat < f64 > ) -> Option < Mat < f64 > > {
538
+ let eigs_grads = cov_grads. self_adjoint_eigen ( faer:: Side :: Lower ) . ok ( ) ?;
528
539
529
540
let u = eigs_grads. U ( ) ;
530
541
let eigs = eigs_grads. S ( ) . column_vector ( ) . to_owned ( ) ;
@@ -534,9 +545,7 @@ fn spd_mean(cov_draws: Mat<f64>, cov_grads: Mat<f64>) -> Mat<f64> {
534
545
let cov_grads_sqrt = u * eigs_sqrt. into_diagonal ( ) * u. transpose ( ) ;
535
546
let m = ( & cov_grads_sqrt) * cov_draws * cov_grads_sqrt;
536
547
537
- let m_eig = m
538
- . self_adjoint_eigen ( faer:: Side :: Lower )
539
- . expect ( "Failed to compute eigenvalue decomposition" ) ;
548
+ let m_eig = m. self_adjoint_eigen ( faer:: Side :: Lower ) . ok ( ) ?;
540
549
541
550
let m_u = m_eig. U ( ) ;
542
551
let mut m_s = m_eig. S ( ) . column_vector ( ) . to_owned ( ) ;
@@ -550,7 +559,7 @@ fn spd_mean(cov_draws: Mat<f64>, cov_grads: Mat<f64>) -> Mat<f64> {
550
559
. for_each ( |val| * val = val. sqrt ( ) . recip ( ) ) ;
551
560
let grads_inv_sqrt = u * eigs_grads_inv. into_diagonal ( ) * u. transpose ( ) ;
552
561
553
- ( & grads_inv_sqrt) * m_sqrt * grads_inv_sqrt
562
+ Some ( ( & grads_inv_sqrt) * m_sqrt * grads_inv_sqrt)
554
563
}
555
564
556
565
impl < M : Math > SamplerStats < M > for LowRankMassMatrixStrategy {
@@ -656,7 +665,7 @@ mod test {
656
665
x. diagonal_mut ( ) . column_vector_mut ( ) . add_assign ( x_diag) ;
657
666
y. diagonal_mut ( ) . column_vector_mut ( ) . add_assign ( y_diag) ;
658
667
659
- let out = spd_mean ( x, y) ;
668
+ let out = spd_mean ( x, y) . expect ( "Failed to compute spd mean" ) ;
660
669
let expected_diag = faer:: col![ 1. , 2. , 4. ] ;
661
670
let mut expected = faer:: Mat :: zeros ( 3 , 3 ) ;
662
671
expected
@@ -684,7 +693,8 @@ mod test {
684
693
//let grads: Mat<f64> = Mat::from_fn(20, 3, |_, _| rng.sample(distr));
685
694
let grads = -( & draws) ;
686
695
687
- let ( vals, vecs) = estimate_mass_matrix ( draws, grads, 0.0001 ) ;
696
+ let ( vals, vecs) =
697
+ estimate_mass_matrix ( draws, grads, 0.0001 ) . expect ( "Failed to compute mass matrix" ) ;
688
698
assert ! ( vals. iter( ) . cloned( ) . all( |x| x > 0. ) ) ;
689
699
assert ! ( mat_all_finite( & vecs. as_ref( ) ) ) ;
690
700
0 commit comments