|
1 | 1 | use nuts_derive::Storable; |
| 2 | +use nuts_storable::{HasDims, Storable}; |
2 | 3 | use serde::Serialize; |
3 | 4 |
|
4 | 5 | use crate::adapt_strategy::CombinedCollector; |
5 | 6 | use crate::chain::AdaptStrategy; |
6 | 7 | use crate::hamiltonian::{Hamiltonian, Point}; |
7 | 8 | use crate::nuts::{Collector, NutsOptions, SampleInfo}; |
8 | | -use crate::sampler_stats::SamplerStats; |
| 9 | +use crate::sampler_stats::{SamplerStats, StatsDims}; |
9 | 10 | use crate::state::State; |
10 | 11 | use crate::stepsize::AcceptanceRateCollector; |
11 | 12 | use crate::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy}; |
@@ -43,17 +44,23 @@ pub struct TransformAdaptation { |
43 | 44 | } |
44 | 45 |
|
45 | 46 | #[derive(Debug, Storable)] |
46 | | -pub struct Stats { |
| 47 | +pub struct Stats<P: HasDims, S: Storable<P>> { |
47 | 48 | tuning: bool, |
| 49 | + #[storable(flatten)] |
| 50 | + pub step_size: S, |
| 51 | + #[storable(ignore)] |
| 52 | + _phantom: std::marker::PhantomData<fn() -> P>, |
48 | 53 | } |
49 | 54 |
|
50 | 55 | impl<M: Math> SamplerStats<M> for TransformAdaptation { |
51 | | - type Stats = Stats; |
| 56 | + type Stats = Stats<StatsDims, <StepSizeStrategy as SamplerStats<M>>::Stats>; |
52 | 57 | type StatsOptions = (); |
53 | 58 |
|
54 | | - fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { |
| 59 | + fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { |
55 | 60 | Stats { |
56 | 61 | tuning: self.tuning, |
| 62 | + step_size: { self.step_size.extract_stats(math, ()) }, |
| 63 | + _phantom: std::marker::PhantomData, |
57 | 64 | } |
58 | 65 | } |
59 | 66 | } |
|
0 commit comments