@@ -16,8 +16,7 @@ use crate::{
16
16
state:: State ,
17
17
stepsize:: AcceptanceRateCollector ,
18
18
stepsize_adapt:: {
19
- DualAverageSettings , Stats as StepSizeStats , StatsBuilder as StepSizeStatsBuilder ,
20
- Strategy as StepSizeStrategy ,
19
+ DualAverageSettings , StatsBuilder as StepSizeStatsBuilder , Strategy as StepSizeStrategy ,
21
20
} ,
22
21
NutsError ,
23
22
} ;
@@ -63,20 +62,18 @@ impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
63
62
}
64
63
65
64
impl < M : Math , A : MassMatrixAdaptStrategy < M > > SamplerStats < M > for GlobalStrategy < M , A > {
66
- type Stats = CombinedStats < StepSizeStats , A :: Stats > ;
67
- type Builder = CombinedStatsBuilder < StepSizeStatsBuilder , A :: Builder > ;
65
+ type Builder = GlobalStrategyBuilder < A :: Builder > ;
66
+ type StatOptions = < A as SamplerStats < M > > :: StatOptions ;
68
67
69
- fn current_stats ( & self , math : & mut M ) -> Self :: Stats {
70
- CombinedStats {
71
- stats1 : self . step_size . current_stats ( math) ,
72
- stats2 : self . mass_matrix . current_stats ( math) ,
73
- }
74
- }
75
-
76
- fn new_builder ( & self , settings : & impl Settings , dim : usize ) -> Self :: Builder {
77
- CombinedStatsBuilder {
78
- stats1 : SamplerStats :: < M > :: new_builder ( & self . step_size , settings, dim) ,
79
- stats2 : self . mass_matrix . new_builder ( settings, dim) ,
68
+ fn new_builder (
69
+ & self ,
70
+ options : Self :: StatOptions ,
71
+ settings : & impl Settings ,
72
+ dim : usize ,
73
+ ) -> Self :: Builder {
74
+ GlobalStrategyBuilder {
75
+ step_size : SamplerStats :: < M > :: new_builder ( & self . step_size , ( ) , settings, dim) ,
76
+ mass_matrix : self . mass_matrix . new_builder ( options, settings, dim) ,
80
77
}
81
78
}
82
79
}
@@ -218,33 +215,37 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
218
215
fn is_tuning ( & self ) -> bool {
219
216
self . tuning
220
217
}
221
- }
222
218
223
- #[ derive( Debug , Clone ) ]
224
- pub struct CombinedStats < D1 , D2 > {
225
- pub stats1 : D1 ,
226
- pub stats2 : D2 ,
219
+ fn last_num_steps ( & self ) -> u64 {
220
+ self . step_size . last_n_steps
221
+ }
227
222
}
228
223
229
- #[ derive( Clone ) ]
230
- pub struct CombinedStatsBuilder < B1 , B2 > {
231
- pub stats1 : B1 ,
232
- pub stats2 : B2 ,
224
+ pub struct GlobalStrategyBuilder < B > {
225
+ pub step_size : StepSizeStatsBuilder ,
226
+ pub mass_matrix : B ,
233
227
}
234
228
235
- impl < S1 , S2 , B1 , B2 > StatTraceBuilder < CombinedStats < S1 , S2 > > for CombinedStatsBuilder < B1 , B2 >
229
+ impl < M : Math , A > StatTraceBuilder < M , GlobalStrategy < M , A > > for GlobalStrategyBuilder < A :: Builder >
236
230
where
237
- B1 : StatTraceBuilder < S1 > ,
238
- B2 : StatTraceBuilder < S2 > ,
231
+ A : MassMatrixAdaptStrategy < M > ,
239
232
{
240
- fn append_value ( & mut self , value : CombinedStats < S1 , S2 > ) {
241
- self . stats1 . append_value ( value. stats1 ) ;
242
- self . stats2 . append_value ( value. stats2 ) ;
233
+ fn append_value ( & mut self , math : Option < & mut M > , value : & GlobalStrategy < M , A > ) {
234
+ let math = math. expect ( "Smapler stats need math" ) ;
235
+ self . step_size . append_value ( Some ( math) , & value. step_size ) ;
236
+ self . mass_matrix
237
+ . append_value ( Some ( math) , & value. mass_matrix ) ;
243
238
}
244
239
245
240
fn finalize ( self ) -> Option < StructArray > {
246
- let Self { stats1, stats2 } = self ;
247
- match ( stats1. finalize ( ) , stats2. finalize ( ) ) {
241
+ let Self {
242
+ step_size,
243
+ mass_matrix,
244
+ } = self ;
245
+ match (
246
+ StatTraceBuilder :: < M , _ > :: finalize ( step_size) ,
247
+ mass_matrix. finalize ( ) ,
248
+ ) {
248
249
( None , None ) => None ,
249
250
( Some ( stats1) , None ) => Some ( stats1) ,
250
251
( None , Some ( stats2) ) => Some ( stats2) ,
@@ -266,8 +267,14 @@ where
266
267
}
267
268
268
269
fn inspect ( & self ) -> Option < StructArray > {
269
- let Self { stats1, stats2 } = self ;
270
- match ( stats1. inspect ( ) , stats2. inspect ( ) ) {
270
+ let Self {
271
+ step_size,
272
+ mass_matrix,
273
+ } = self ;
274
+ match (
275
+ StatTraceBuilder :: < M , _ > :: inspect ( step_size) ,
276
+ mass_matrix. inspect ( ) ,
277
+ ) {
271
278
( None , None ) => None ,
272
279
( Some ( stats1) , None ) => Some ( stats1) ,
273
280
( None , Some ( stats2) ) => Some ( stats2) ,
@@ -374,6 +381,7 @@ pub mod test_logps {
374
381
375
382
#[ derive( Error , Debug ) ]
376
383
pub enum NormalLogpError { }
384
+
377
385
impl LogpError for NormalLogpError {
378
386
fn is_recoverable ( & self ) -> bool {
379
387
false
@@ -438,6 +446,7 @@ pub mod test_logps {
438
446
_rng : & mut R ,
439
447
_untransformed_positions : impl Iterator < Item = & ' a [ f64 ] > ,
440
448
_untransformed_gradients : impl Iterator < Item = & ' a [ f64 ] > ,
449
+ _untransformed_logp : impl Iterator < Item = & ' a f64 > ,
441
450
_params : & ' a mut Self :: TransformParams ,
442
451
) -> Result < ( ) , Self :: LogpError > {
443
452
unimplemented ! ( )
0 commit comments