@@ -17,7 +17,7 @@ use arrow::array::Array;
17
17
use numpy:: { PyArray1 , PyReadonlyArray1 } ;
18
18
use nuts_rs:: {
19
19
ChainProgress , DiagGradNutsSettings , LowRankNutsSettings , ProgressCallback , Sampler ,
20
- SamplerWaitResult , Trace , TransformedNutsSettings ,
20
+ SamplerWaitResult , StepSizeAdaptMethod , Trace , TransformedNutsSettings ,
21
21
} ;
22
22
use pyo3:: {
23
23
exceptions:: PyTimeoutError ,
@@ -276,22 +276,13 @@ impl PyNutsSettings {
276
276
fn initial_step ( & self ) -> f64 {
277
277
match & self . inner {
278
278
Settings :: Diag ( nuts_settings) => {
279
- nuts_settings
280
- . adapt_options
281
- . dual_average_options
282
- . initial_step
279
+ nuts_settings. adapt_options . step_size_settings . initial_step
283
280
}
284
281
Settings :: LowRank ( nuts_settings) => {
285
- nuts_settings
286
- . adapt_options
287
- . dual_average_options
288
- . initial_step
282
+ nuts_settings. adapt_options . step_size_settings . initial_step
289
283
}
290
284
Settings :: Transforming ( nuts_settings) => {
291
- nuts_settings
292
- . adapt_options
293
- . dual_average_options
294
- . initial_step
285
+ nuts_settings. adapt_options . step_size_settings . initial_step
295
286
}
296
287
}
297
288
}
@@ -300,22 +291,13 @@ impl PyNutsSettings {
300
291
fn set_initial_step ( & mut self , val : f64 ) {
301
292
match & mut self . inner {
302
293
Settings :: Diag ( nuts_settings) => {
303
- nuts_settings
304
- . adapt_options
305
- . dual_average_options
306
- . initial_step = val;
294
+ nuts_settings. adapt_options . step_size_settings . initial_step = val;
307
295
}
308
296
Settings :: LowRank ( nuts_settings) => {
309
- nuts_settings
310
- . adapt_options
311
- . dual_average_options
312
- . initial_step = val;
297
+ nuts_settings. adapt_options . step_size_settings . initial_step = val;
313
298
}
314
299
Settings :: Transforming ( nuts_settings) => {
315
- nuts_settings
316
- . adapt_options
317
- . dual_average_options
318
- . initial_step = val;
300
+ nuts_settings. adapt_options . step_size_settings . initial_step = val;
319
301
}
320
302
}
321
303
}
@@ -414,22 +396,13 @@ impl PyNutsSettings {
414
396
fn set_target_accept ( & self ) -> f64 {
415
397
match & self . inner {
416
398
Settings :: Diag ( nuts_settings) => {
417
- nuts_settings
418
- . adapt_options
419
- . dual_average_options
420
- . target_accept
399
+ nuts_settings. adapt_options . step_size_settings . target_accept
421
400
}
422
401
Settings :: LowRank ( nuts_settings) => {
423
- nuts_settings
424
- . adapt_options
425
- . dual_average_options
426
- . target_accept
402
+ nuts_settings. adapt_options . step_size_settings . target_accept
427
403
}
428
404
Settings :: Transforming ( nuts_settings) => {
429
- nuts_settings
430
- . adapt_options
431
- . dual_average_options
432
- . target_accept
405
+ nuts_settings. adapt_options . step_size_settings . target_accept
433
406
}
434
407
}
435
408
}
@@ -438,22 +411,13 @@ impl PyNutsSettings {
438
411
fn target_accept ( & mut self , val : f64 ) {
439
412
match & mut self . inner {
440
413
Settings :: Diag ( nuts_settings) => {
441
- nuts_settings
442
- . adapt_options
443
- . dual_average_options
444
- . target_accept = val
414
+ nuts_settings. adapt_options . step_size_settings . target_accept = val
445
415
}
446
416
Settings :: LowRank ( nuts_settings) => {
447
- nuts_settings
448
- . adapt_options
449
- . dual_average_options
450
- . target_accept = val
417
+ nuts_settings. adapt_options . step_size_settings . target_accept = val
451
418
}
452
419
Settings :: Transforming ( nuts_settings) => {
453
- nuts_settings
454
- . adapt_options
455
- . dual_average_options
456
- . target_accept = val
420
+ nuts_settings. adapt_options . step_size_settings . target_accept = val
457
421
}
458
422
}
459
423
}
@@ -654,6 +618,146 @@ impl PyNutsSettings {
654
618
}
655
619
Ok ( ( ) )
656
620
}
621
+
622
+ #[ getter]
623
+ fn step_size_adapt_method ( & self ) -> String {
624
+ let method = match & self . inner {
625
+ Settings :: LowRank ( inner) => inner. adapt_options . step_size_settings . adapt_options . method ,
626
+ Settings :: Diag ( inner) => inner. adapt_options . step_size_settings . adapt_options . method ,
627
+ Settings :: Transforming ( inner) => {
628
+ inner. adapt_options . step_size_settings . adapt_options . method
629
+ }
630
+ } ;
631
+
632
+ match method {
633
+ nuts_rs:: StepSizeAdaptMethod :: DualAverage => "dual_average" ,
634
+ nuts_rs:: StepSizeAdaptMethod :: Adam => "adam" ,
635
+ nuts_rs:: StepSizeAdaptMethod :: Fixed ( _) => "fixed" ,
636
+ }
637
+ . to_string ( )
638
+ }
639
+
640
+ #[ setter( step_size_adapt_method) ]
641
+ fn set_step_size_adapt_method ( & mut self , method : Py < PyAny > ) -> Result < ( ) > {
642
+ let method = Python :: with_gil ( |py| {
643
+ if let Ok ( method) = method. extract :: < String > ( py) {
644
+ match method. as_str ( ) {
645
+ "dual_average" => Ok ( StepSizeAdaptMethod :: DualAverage ) ,
646
+ "adam" => Ok ( StepSizeAdaptMethod :: Adam ) ,
647
+ _ => {
648
+ if let Ok ( step_size) = method. parse :: < f64 > ( ) {
649
+ Ok ( StepSizeAdaptMethod :: Fixed ( step_size) )
650
+ } else {
651
+ bail ! ( "step_size_adapt_method must be a positive float when using fixed step size" ) ;
652
+ }
653
+ }
654
+ }
655
+ } else {
656
+ bail ! ( "step_size_adapt_method must be a string" ) ;
657
+ }
658
+ } ) ?;
659
+
660
+ match & mut self . inner {
661
+ Settings :: LowRank ( inner) => {
662
+ inner. adapt_options . step_size_settings . adapt_options . method = method
663
+ }
664
+ Settings :: Diag ( inner) => {
665
+ inner. adapt_options . step_size_settings . adapt_options . method = method
666
+ }
667
+ Settings :: Transforming ( inner) => {
668
+ inner. adapt_options . step_size_settings . adapt_options . method = method
669
+ }
670
+ } ;
671
+ Ok ( ( ) )
672
+ }
673
+
674
+ #[ getter]
675
+ fn step_size_adam_learning_rate ( & self ) -> Option < f64 > {
676
+ match & self . inner {
677
+ Settings :: LowRank ( inner) => {
678
+ if let StepSizeAdaptMethod :: Adam =
679
+ inner. adapt_options . step_size_settings . adapt_options . method
680
+ {
681
+ Some (
682
+ inner
683
+ . adapt_options
684
+ . step_size_settings
685
+ . adapt_options
686
+ . adam
687
+ . learning_rate ,
688
+ )
689
+ } else {
690
+ None
691
+ }
692
+ }
693
+ Settings :: Diag ( inner) => {
694
+ if let StepSizeAdaptMethod :: Adam =
695
+ inner. adapt_options . step_size_settings . adapt_options . method
696
+ {
697
+ Some (
698
+ inner
699
+ . adapt_options
700
+ . step_size_settings
701
+ . adapt_options
702
+ . adam
703
+ . learning_rate ,
704
+ )
705
+ } else {
706
+ None
707
+ }
708
+ }
709
+ Settings :: Transforming ( inner) => {
710
+ if let StepSizeAdaptMethod :: Adam =
711
+ inner. adapt_options . step_size_settings . adapt_options . method
712
+ {
713
+ Some (
714
+ inner
715
+ . adapt_options
716
+ . step_size_settings
717
+ . adapt_options
718
+ . adam
719
+ . learning_rate ,
720
+ )
721
+ } else {
722
+ None
723
+ }
724
+ }
725
+ }
726
+ }
727
+
728
+ #[ setter( step_size_adam_learning_rate) ]
729
+ fn set_step_size_adam_learning_rate ( & mut self , val : Option < f64 > ) -> Result < ( ) > {
730
+ let Some ( val) = val else {
731
+ return Ok ( ( ) ) ;
732
+ } ;
733
+ match & mut self . inner {
734
+ Settings :: LowRank ( inner) => {
735
+ inner
736
+ . adapt_options
737
+ . step_size_settings
738
+ . adapt_options
739
+ . adam
740
+ . learning_rate = val
741
+ }
742
+ Settings :: Diag ( inner) => {
743
+ inner
744
+ . adapt_options
745
+ . step_size_settings
746
+ . adapt_options
747
+ . adam
748
+ . learning_rate = val
749
+ }
750
+ Settings :: Transforming ( inner) => {
751
+ inner
752
+ . adapt_options
753
+ . step_size_settings
754
+ . adapt_options
755
+ . adam
756
+ . learning_rate = val
757
+ }
758
+ } ;
759
+ Ok ( ( ) )
760
+ }
657
761
}
658
762
659
763
pub ( crate ) enum SamplerState {
0 commit comments