@@ -16,8 +16,7 @@ use crate::{
1616 state:: State ,
1717 stepsize:: AcceptanceRateCollector ,
1818 stepsize_adapt:: {
19- DualAverageSettings , Stats as StepSizeStats , StatsBuilder as StepSizeStatsBuilder ,
20- Strategy as StepSizeStrategy ,
19+ DualAverageSettings , StatsBuilder as StepSizeStatsBuilder , Strategy as StepSizeStrategy ,
2120 } ,
2221 NutsError ,
2322} ;
@@ -63,20 +62,18 @@ impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
6362}
6463
6564impl < M : Math , A : MassMatrixAdaptStrategy < M > > SamplerStats < M > for GlobalStrategy < M , A > {
66- type Stats = CombinedStats < StepSizeStats , A :: Stats > ;
67- type Builder = CombinedStatsBuilder < StepSizeStatsBuilder , A :: Builder > ;
65+ type Builder = GlobalStrategyBuilder < A :: Builder > ;
66+ type StatOptions = < A as SamplerStats < M > > :: StatOptions ;
6867
69- fn current_stats ( & self , math : & mut M ) -> Self :: Stats {
70- CombinedStats {
71- stats1 : self . step_size . current_stats ( math) ,
72- stats2 : self . mass_matrix . current_stats ( math) ,
73- }
74- }
75-
76- fn new_builder ( & self , settings : & impl Settings , dim : usize ) -> Self :: Builder {
77- CombinedStatsBuilder {
78- stats1 : SamplerStats :: < M > :: new_builder ( & self . step_size , settings, dim) ,
79- stats2 : self . mass_matrix . new_builder ( settings, dim) ,
68+ fn new_builder (
69+ & self ,
70+ options : Self :: StatOptions ,
71+ settings : & impl Settings ,
72+ dim : usize ,
73+ ) -> Self :: Builder {
74+ GlobalStrategyBuilder {
75+ step_size : SamplerStats :: < M > :: new_builder ( & self . step_size , ( ) , settings, dim) ,
76+ mass_matrix : self . mass_matrix . new_builder ( options, settings, dim) ,
8077 }
8178 }
8279}
@@ -218,33 +215,37 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
218215 fn is_tuning ( & self ) -> bool {
219216 self . tuning
220217 }
221- }
222218
223- #[ derive( Debug , Clone ) ]
224- pub struct CombinedStats < D1 , D2 > {
225- pub stats1 : D1 ,
226- pub stats2 : D2 ,
219+ fn last_num_steps ( & self ) -> u64 {
220+ self . step_size . last_n_steps
221+ }
227222}
228223
229- #[ derive( Clone ) ]
230- pub struct CombinedStatsBuilder < B1 , B2 > {
231- pub stats1 : B1 ,
232- pub stats2 : B2 ,
224+ pub struct GlobalStrategyBuilder < B > {
225+ pub step_size : StepSizeStatsBuilder ,
226+ pub mass_matrix : B ,
233227}
234228
235- impl < S1 , S2 , B1 , B2 > StatTraceBuilder < CombinedStats < S1 , S2 > > for CombinedStatsBuilder < B1 , B2 >
229+ impl < M : Math , A > StatTraceBuilder < M , GlobalStrategy < M , A > > for GlobalStrategyBuilder < A :: Builder >
236230where
237- B1 : StatTraceBuilder < S1 > ,
238- B2 : StatTraceBuilder < S2 > ,
231+ A : MassMatrixAdaptStrategy < M > ,
239232{
240- fn append_value ( & mut self , value : CombinedStats < S1 , S2 > ) {
241- self . stats1 . append_value ( value. stats1 ) ;
242- self . stats2 . append_value ( value. stats2 ) ;
233+ fn append_value ( & mut self , math : Option < & mut M > , value : & GlobalStrategy < M , A > ) {
234+ let math = math. expect ( "Smapler stats need math" ) ;
235+ self . step_size . append_value ( Some ( math) , & value. step_size ) ;
236+ self . mass_matrix
237+ . append_value ( Some ( math) , & value. mass_matrix ) ;
243238 }
244239
245240 fn finalize ( self ) -> Option < StructArray > {
246- let Self { stats1, stats2 } = self ;
247- match ( stats1. finalize ( ) , stats2. finalize ( ) ) {
241+ let Self {
242+ step_size,
243+ mass_matrix,
244+ } = self ;
245+ match (
246+ StatTraceBuilder :: < M , _ > :: finalize ( step_size) ,
247+ mass_matrix. finalize ( ) ,
248+ ) {
248249 ( None , None ) => None ,
249250 ( Some ( stats1) , None ) => Some ( stats1) ,
250251 ( None , Some ( stats2) ) => Some ( stats2) ,
@@ -266,8 +267,14 @@ where
266267 }
267268
268269 fn inspect ( & self ) -> Option < StructArray > {
269- let Self { stats1, stats2 } = self ;
270- match ( stats1. inspect ( ) , stats2. inspect ( ) ) {
270+ let Self {
271+ step_size,
272+ mass_matrix,
273+ } = self ;
274+ match (
275+ StatTraceBuilder :: < M , _ > :: inspect ( step_size) ,
276+ mass_matrix. inspect ( ) ,
277+ ) {
271278 ( None , None ) => None ,
272279 ( Some ( stats1) , None ) => Some ( stats1) ,
273280 ( None , Some ( stats2) ) => Some ( stats2) ,
@@ -374,6 +381,7 @@ pub mod test_logps {
374381
375382 #[ derive( Error , Debug ) ]
376383 pub enum NormalLogpError { }
384+
377385 impl LogpError for NormalLogpError {
378386 fn is_recoverable ( & self ) -> bool {
379387 false
@@ -438,6 +446,7 @@ pub mod test_logps {
438446 _rng : & mut R ,
439447 _untransformed_positions : impl Iterator < Item = & ' a [ f64 ] > ,
440448 _untransformed_gradients : impl Iterator < Item = & ' a [ f64 ] > ,
449+ _untransformed_logp : impl Iterator < Item = & ' a f64 > ,
441450 _params : & ' a mut Self :: TransformParams ,
442451 ) -> Result < ( ) , Self :: LogpError > {
443452 unimplemented ! ( )
0 commit comments