Skip to content

Commit 31294e1

Browse files
committed
feat: support step size adaptation method
1 parent df2162d commit 31294e1

File tree

1 file changed

+153
-49
lines changed

1 file changed

+153
-49
lines changed

src/wrapper.rs

Lines changed: 153 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use arrow::array::Array;
1717
use numpy::{PyArray1, PyReadonlyArray1};
1818
use nuts_rs::{
1919
ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, ProgressCallback, Sampler,
20-
SamplerWaitResult, Trace, TransformedNutsSettings,
20+
SamplerWaitResult, StepSizeAdaptMethod, Trace, TransformedNutsSettings,
2121
};
2222
use pyo3::{
2323
exceptions::PyTimeoutError,
@@ -276,22 +276,13 @@ impl PyNutsSettings {
276276
fn initial_step(&self) -> f64 {
277277
match &self.inner {
278278
Settings::Diag(nuts_settings) => {
279-
nuts_settings
280-
.adapt_options
281-
.dual_average_options
282-
.initial_step
279+
nuts_settings.adapt_options.step_size_settings.initial_step
283280
}
284281
Settings::LowRank(nuts_settings) => {
285-
nuts_settings
286-
.adapt_options
287-
.dual_average_options
288-
.initial_step
282+
nuts_settings.adapt_options.step_size_settings.initial_step
289283
}
290284
Settings::Transforming(nuts_settings) => {
291-
nuts_settings
292-
.adapt_options
293-
.dual_average_options
294-
.initial_step
285+
nuts_settings.adapt_options.step_size_settings.initial_step
295286
}
296287
}
297288
}
@@ -300,22 +291,13 @@ impl PyNutsSettings {
300291
fn set_initial_step(&mut self, val: f64) {
301292
match &mut self.inner {
302293
Settings::Diag(nuts_settings) => {
303-
nuts_settings
304-
.adapt_options
305-
.dual_average_options
306-
.initial_step = val;
294+
nuts_settings.adapt_options.step_size_settings.initial_step = val;
307295
}
308296
Settings::LowRank(nuts_settings) => {
309-
nuts_settings
310-
.adapt_options
311-
.dual_average_options
312-
.initial_step = val;
297+
nuts_settings.adapt_options.step_size_settings.initial_step = val;
313298
}
314299
Settings::Transforming(nuts_settings) => {
315-
nuts_settings
316-
.adapt_options
317-
.dual_average_options
318-
.initial_step = val;
300+
nuts_settings.adapt_options.step_size_settings.initial_step = val;
319301
}
320302
}
321303
}
@@ -414,22 +396,13 @@ impl PyNutsSettings {
414396
fn set_target_accept(&self) -> f64 {
415397
match &self.inner {
416398
Settings::Diag(nuts_settings) => {
417-
nuts_settings
418-
.adapt_options
419-
.dual_average_options
420-
.target_accept
399+
nuts_settings.adapt_options.step_size_settings.target_accept
421400
}
422401
Settings::LowRank(nuts_settings) => {
423-
nuts_settings
424-
.adapt_options
425-
.dual_average_options
426-
.target_accept
402+
nuts_settings.adapt_options.step_size_settings.target_accept
427403
}
428404
Settings::Transforming(nuts_settings) => {
429-
nuts_settings
430-
.adapt_options
431-
.dual_average_options
432-
.target_accept
405+
nuts_settings.adapt_options.step_size_settings.target_accept
433406
}
434407
}
435408
}
@@ -438,22 +411,13 @@ impl PyNutsSettings {
438411
fn target_accept(&mut self, val: f64) {
439412
match &mut self.inner {
440413
Settings::Diag(nuts_settings) => {
441-
nuts_settings
442-
.adapt_options
443-
.dual_average_options
444-
.target_accept = val
414+
nuts_settings.adapt_options.step_size_settings.target_accept = val
445415
}
446416
Settings::LowRank(nuts_settings) => {
447-
nuts_settings
448-
.adapt_options
449-
.dual_average_options
450-
.target_accept = val
417+
nuts_settings.adapt_options.step_size_settings.target_accept = val
451418
}
452419
Settings::Transforming(nuts_settings) => {
453-
nuts_settings
454-
.adapt_options
455-
.dual_average_options
456-
.target_accept = val
420+
nuts_settings.adapt_options.step_size_settings.target_accept = val
457421
}
458422
}
459423
}
@@ -654,6 +618,146 @@ impl PyNutsSettings {
654618
}
655619
Ok(())
656620
}
621+
622+
#[getter]
623+
fn step_size_adapt_method(&self) -> String {
624+
let method = match &self.inner {
625+
Settings::LowRank(inner) => inner.adapt_options.step_size_settings.adapt_options.method,
626+
Settings::Diag(inner) => inner.adapt_options.step_size_settings.adapt_options.method,
627+
Settings::Transforming(inner) => {
628+
inner.adapt_options.step_size_settings.adapt_options.method
629+
}
630+
};
631+
632+
match method {
633+
nuts_rs::StepSizeAdaptMethod::DualAverage => "dual_average",
634+
nuts_rs::StepSizeAdaptMethod::Adam => "adam",
635+
nuts_rs::StepSizeAdaptMethod::Fixed(_) => "fixed",
636+
}
637+
.to_string()
638+
}
639+
640+
#[setter(step_size_adapt_method)]
641+
fn set_step_size_adapt_method(&mut self, method: Py<PyAny>) -> Result<()> {
642+
let method = Python::with_gil(|py| {
643+
if let Ok(method) = method.extract::<String>(py) {
644+
match method.as_str() {
645+
"dual_average" => Ok(StepSizeAdaptMethod::DualAverage),
646+
"adam" => Ok(StepSizeAdaptMethod::Adam),
647+
_ => {
648+
if let Ok(step_size) = method.parse::<f64>() {
649+
Ok(StepSizeAdaptMethod::Fixed(step_size))
650+
} else {
651+
bail!("step_size_adapt_method must be a positive float when using fixed step size");
652+
}
653+
}
654+
}
655+
} else {
656+
bail!("step_size_adapt_method must be a string");
657+
}
658+
})?;
659+
660+
match &mut self.inner {
661+
Settings::LowRank(inner) => {
662+
inner.adapt_options.step_size_settings.adapt_options.method = method
663+
}
664+
Settings::Diag(inner) => {
665+
inner.adapt_options.step_size_settings.adapt_options.method = method
666+
}
667+
Settings::Transforming(inner) => {
668+
inner.adapt_options.step_size_settings.adapt_options.method = method
669+
}
670+
};
671+
Ok(())
672+
}
673+
674+
#[getter]
675+
fn step_size_adam_learning_rate(&self) -> Option<f64> {
676+
match &self.inner {
677+
Settings::LowRank(inner) => {
678+
if let StepSizeAdaptMethod::Adam =
679+
inner.adapt_options.step_size_settings.adapt_options.method
680+
{
681+
Some(
682+
inner
683+
.adapt_options
684+
.step_size_settings
685+
.adapt_options
686+
.adam
687+
.learning_rate,
688+
)
689+
} else {
690+
None
691+
}
692+
}
693+
Settings::Diag(inner) => {
694+
if let StepSizeAdaptMethod::Adam =
695+
inner.adapt_options.step_size_settings.adapt_options.method
696+
{
697+
Some(
698+
inner
699+
.adapt_options
700+
.step_size_settings
701+
.adapt_options
702+
.adam
703+
.learning_rate,
704+
)
705+
} else {
706+
None
707+
}
708+
}
709+
Settings::Transforming(inner) => {
710+
if let StepSizeAdaptMethod::Adam =
711+
inner.adapt_options.step_size_settings.adapt_options.method
712+
{
713+
Some(
714+
inner
715+
.adapt_options
716+
.step_size_settings
717+
.adapt_options
718+
.adam
719+
.learning_rate,
720+
)
721+
} else {
722+
None
723+
}
724+
}
725+
}
726+
}
727+
728+
#[setter(step_size_adam_learning_rate)]
729+
fn set_step_size_adam_learning_rate(&mut self, val: Option<f64>) -> Result<()> {
730+
let Some(val) = val else {
731+
return Ok(());
732+
};
733+
match &mut self.inner {
734+
Settings::LowRank(inner) => {
735+
inner
736+
.adapt_options
737+
.step_size_settings
738+
.adapt_options
739+
.adam
740+
.learning_rate = val
741+
}
742+
Settings::Diag(inner) => {
743+
inner
744+
.adapt_options
745+
.step_size_settings
746+
.adapt_options
747+
.adam
748+
.learning_rate = val
749+
}
750+
Settings::Transforming(inner) => {
751+
inner
752+
.adapt_options
753+
.step_size_settings
754+
.adapt_options
755+
.adam
756+
.learning_rate = val
757+
}
758+
};
759+
Ok(())
760+
}
657761
}
658762

659763
pub(crate) enum SamplerState {

0 commit comments

Comments
 (0)