Skip to content

Commit 6e84e32

Browse files
committed
fix: store step size info in transform_adapt_strategy
1 parent 147ea76 commit 6e84e32

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/transform_adapt_strategy.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use nuts_derive::Storable;
2+
use nuts_storable::{HasDims, Storable};
23
use serde::Serialize;
34

45
use crate::adapt_strategy::CombinedCollector;
56
use crate::chain::AdaptStrategy;
67
use crate::hamiltonian::{Hamiltonian, Point};
78
use crate::nuts::{Collector, NutsOptions, SampleInfo};
8-
use crate::sampler_stats::SamplerStats;
9+
use crate::sampler_stats::{SamplerStats, StatsDims};
910
use crate::state::State;
1011
use crate::stepsize::AcceptanceRateCollector;
1112
use crate::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
@@ -43,17 +44,23 @@ pub struct TransformAdaptation {
4344
}
4445

4546
#[derive(Debug, Storable)]
46-
pub struct Stats {
47+
pub struct Stats<P: HasDims, S: Storable<P>> {
4748
tuning: bool,
49+
#[storable(flatten)]
50+
pub step_size: S,
51+
#[storable(ignore)]
52+
_phantom: std::marker::PhantomData<fn() -> P>,
4853
}
4954

5055
impl<M: Math> SamplerStats<M> for TransformAdaptation {
51-
type Stats = Stats;
56+
type Stats = Stats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats>;
5257
type StatsOptions = ();
5358

54-
fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
59+
fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
5560
Stats {
5661
tuning: self.tuning,
62+
step_size: { self.step_size.extract_stats(math, ()) },
63+
_phantom: std::marker::PhantomData,
5764
}
5865
}
5966
}

0 commit comments

Comments
 (0)