@@ -5,25 +5,26 @@ use itertools::Itertools;
55use rand:: Rng ;
66
77use crate :: {
8+ chain:: AdaptStrategy ,
9+ hamiltonian:: { DivergenceInfo , Hamiltonian , Point } ,
810 mass_matrix_adapt:: MassMatrixAdaptStrategy ,
911 math_base:: Math ,
10- nuts:: { AdaptStats , AdaptStrategy , Collector , NutsOptions } ,
12+ nuts:: { Collector , NutsOptions } ,
1113 sampler:: Settings ,
14+ sampler_stats:: { SamplerStats , StatTraceBuilder } ,
1215 state:: State ,
1316 stepsize:: AcceptanceRateCollector ,
1417 stepsize_adapt:: {
1518 DualAverageSettings , Stats as StepSizeStats , StatsBuilder as StepSizeStatsBuilder ,
1619 Strategy as StepSizeStrategy ,
1720 } ,
18- DivergenceInfo ,
21+ NutsError ,
1922} ;
2023
21- use crate :: nuts:: { SamplerStats , StatTraceBuilder } ;
22-
2324pub struct GlobalStrategy < M : Math , A : MassMatrixAdaptStrategy < M > > {
2425 step_size : StepSizeStrategy ,
2526 mass_matrix : A ,
26- options : AdaptOptions < A :: Options > ,
27+ options : EuclideanAdaptOptions < A :: Options > ,
2728 num_tune : u64 ,
2829 // The number of draws in the the early window
2930 early_end : u64 ,
@@ -36,7 +37,7 @@ pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
3637}
3738
3839#[ derive( Debug , Clone , Copy ) ]
39- pub struct AdaptOptions < S : Debug + Default > {
40+ pub struct EuclideanAdaptOptions < S : Debug + Default > {
4041 pub dual_average_options : DualAverageSettings ,
4142 pub mass_matrix_options : S ,
4243 pub early_window : f64 ,
@@ -46,7 +47,7 @@ pub struct AdaptOptions<S: Debug + Default> {
4647 pub mass_matrix_update_freq : u64 ,
4748}
4849
49- impl < S : Debug + Default > Default for AdaptOptions < S > {
50+ impl < S : Debug + Default > Default for EuclideanAdaptOptions < S > {
5051 fn default ( ) -> Self {
5152 Self {
5253 dual_average_options : DualAverageSettings :: default ( ) ,
@@ -79,16 +80,15 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> SamplerStats<M> for GlobalStrategy<
7980 }
8081}
8182
82- impl < M : Math , A : MassMatrixAdaptStrategy < M > > AdaptStats < M > for GlobalStrategy < M , A > {
83- fn num_grad_evals ( stats : & Self :: Stats ) -> usize {
84- stats. stats1 . n_steps as usize
85- }
86- }
87-
8883impl < M : Math , A : MassMatrixAdaptStrategy < M > > AdaptStrategy < M > for GlobalStrategy < M , A > {
89- type Potential = A :: Potential ;
90- type Collector = CombinedCollector < M , AcceptanceRateCollector , A :: Collector > ;
91- type Options = AdaptOptions < A :: Options > ;
84+ type Hamiltonian = A :: Hamiltonian ;
85+ type Collector = CombinedCollector <
86+ M ,
87+ <Self :: Hamiltonian as Hamiltonian < M > >:: Point ,
88+ AcceptanceRateCollector ,
89+ A :: Collector ,
90+ > ;
91+ type Options = EuclideanAdaptOptions < A :: Options > ;
9292
9393 fn new ( math : & mut M , options : Self :: Options , num_tune : u64 ) -> Self {
9494 let num_tune_f = num_tune as f64 ;
@@ -115,29 +115,32 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
115115 & mut self ,
116116 math : & mut M ,
117117 options : & mut NutsOptions ,
118- potential : & mut Self :: Potential ,
119- state : & State < M > ,
118+ hamiltonian : & mut Self :: Hamiltonian ,
119+ state : & State < M , < Self :: Hamiltonian as Hamiltonian < M > > :: Point > ,
120120 rng : & mut R ,
121- ) {
122- self . mass_matrix . init ( math, options, potential, state, rng) ;
123- self . step_size . init ( math, options, potential, state, rng) ;
121+ ) -> Result < ( ) , NutsError > {
122+ self . mass_matrix
123+ . init ( math, options, hamiltonian, state, rng) ?;
124+ self . step_size
125+ . init ( math, options, hamiltonian, state, rng) ?;
126+ Ok ( ( ) )
124127 }
125128
126129 fn adapt < R : Rng + ?Sized > (
127130 & mut self ,
128131 math : & mut M ,
129132 options : & mut NutsOptions ,
130- potential : & mut Self :: Potential ,
133+ hamiltonian : & mut Self :: Hamiltonian ,
131134 draw : u64 ,
132135 collector : & Self :: Collector ,
133- state : & State < M > ,
136+ state : & State < M , < Self :: Hamiltonian as Hamiltonian < M > > :: Point > ,
134137 rng : & mut R ,
135- ) {
138+ ) -> Result < ( ) , NutsError > {
136139 self . step_size . update ( & collector. collector1 ) ;
137140
138141 if draw >= self . num_tune {
139142 self . tuning = false ;
140- return ;
143+ return Ok ( ( ) ) ;
141144 }
142145
143146 if draw < self . final_step_size_window {
@@ -165,7 +168,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
165168 let did_change = if force_update
166169 | ( draw - self . last_update >= self . options . mass_matrix_update_freq )
167170 {
168- self . mass_matrix . update_potential ( math, potential )
171+ self . mass_matrix . update_potential ( math, hamiltonian )
169172 } else {
170173 false
171174 } ;
@@ -183,24 +186,25 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
183186 // First time we change the mass matrix
184187 if did_change & self . has_initial_mass_matrix {
185188 self . has_initial_mass_matrix = false ;
186- self . step_size . init ( math, options, potential, state, rng) ;
189+ self . step_size
190+ . init ( math, options, hamiltonian, state, rng) ?;
187191 } else {
188- self . step_size . update_stepsize ( potential , false )
192+ self . step_size . update_stepsize ( hamiltonian , false )
189193 }
190- return ;
194+ return Ok ( ( ) ) ;
191195 }
192196
193197 self . step_size . update_estimator_late ( ) ;
194198 let is_last = draw == self . num_tune - 1 ;
195- self . step_size . update_stepsize ( potential, is_last) ;
199+ self . step_size . update_stepsize ( hamiltonian, is_last) ;
200+ Ok ( ( ) )
196201 }
197202
198203 fn new_collector ( & self , math : & mut M ) -> Self :: Collector {
199- CombinedCollector {
200- collector1 : self . step_size . new_collector ( ) ,
201- collector2 : self . mass_matrix . new_collector ( math) ,
202- _phantom : PhantomData ,
203- }
204+ Self :: Collector :: new (
205+ self . step_size . new_collector ( ) ,
206+ self . mass_matrix . new_collector ( math) ,
207+ )
204208 }
205209
206210 fn is_tuning ( & self ) -> bool {
@@ -277,22 +281,48 @@ where
277281 }
278282}
279283
280- pub struct CombinedCollector < M : Math , C1 : Collector < M > , C2 : Collector < M > > {
281- collector1 : C1 ,
282- collector2 : C2 ,
284+ pub struct CombinedCollector < M , P , C1 , C2 >
285+ where
286+ M : Math ,
287+ P : Point < M > ,
288+ C1 : Collector < M , P > ,
289+ C2 : Collector < M , P > ,
290+ {
291+ pub collector1 : C1 ,
292+ pub collector2 : C2 ,
283293 _phantom : PhantomData < M > ,
294+ _phantom2 : PhantomData < P > ,
284295}
285296
286- impl < M : Math , C1 , C2 > Collector < M > for CombinedCollector < M , C1 , C2 >
297+ impl < M , P , C1 , C2 > CombinedCollector < M , P , C1 , C2 >
287298where
288- C1 : Collector < M > ,
289- C2 : Collector < M > ,
299+ M : Math ,
300+ P : Point < M > ,
301+ C1 : Collector < M , P > ,
302+ C2 : Collector < M , P > ,
303+ {
304+ pub fn new ( collector1 : C1 , collector2 : C2 ) -> Self {
305+ CombinedCollector {
306+ collector1,
307+ collector2,
308+ _phantom : PhantomData ,
309+ _phantom2 : PhantomData ,
310+ }
311+ }
312+ }
313+
314+ impl < M , P , C1 , C2 > Collector < M , P > for CombinedCollector < M , P , C1 , C2 >
315+ where
316+ M : Math ,
317+ P : Point < M > ,
318+ C1 : Collector < M , P > ,
319+ C2 : Collector < M , P > ,
290320{
291321 fn register_leapfrog (
292322 & mut self ,
293323 math : & mut M ,
294- start : & State < M > ,
295- end : & State < M > ,
324+ start : & State < M , P > ,
325+ end : & State < M , P > ,
296326 divergence_info : Option < & DivergenceInfo > ,
297327 ) {
298328 self . collector1
@@ -301,15 +331,15 @@ where
301331 . register_leapfrog ( math, start, end, divergence_info) ;
302332 }
303333
304- fn register_draw ( & mut self , math : & mut M , state : & State < M > , info : & crate :: nuts:: SampleInfo ) {
334+ fn register_draw ( & mut self , math : & mut M , state : & State < M , P > , info : & crate :: nuts:: SampleInfo ) {
305335 self . collector1 . register_draw ( math, state, info) ;
306336 self . collector2 . register_draw ( math, state, info) ;
307337 }
308338
309339 fn register_init (
310340 & mut self ,
311341 math : & mut M ,
312- state : & State < M > ,
342+ state : & State < M , P > ,
313343 options : & crate :: nuts:: NutsOptions ,
314344 ) {
315345 self . collector1 . register_init ( math, state, options) ;
@@ -319,7 +349,7 @@ where
319349
320350#[ cfg( test) ]
321351pub mod test_logps {
322- use crate :: { cpu_math:: CpuLogpFunc , nuts :: LogpError } ;
352+ use crate :: { cpu_math:: CpuLogpFunc , math_base :: LogpError } ;
323353 use thiserror:: Error ;
324354
325355 #[ derive( Clone , Debug ) ]
@@ -344,6 +374,7 @@ pub mod test_logps {
344374
345375 impl CpuLogpFunc for NormalLogp {
346376 type LogpError = NormalLogpError ;
377+ type TransformParams = ( ) ;
347378
348379 fn dim ( & self ) -> usize {
349380 self . dim
@@ -360,6 +391,50 @@ pub mod test_logps {
360391 }
361392 Ok ( logp)
362393 }
394+
395+ fn inv_transform_normalize (
396+ & mut self ,
397+ _params : & Self :: TransformParams ,
398+ _untransformed_position : & [ f64 ] ,
399+ _untransofrmed_gradient : & [ f64 ] ,
400+ _transformed_position : & mut [ f64 ] ,
401+ _transformed_gradient : & mut [ f64 ] ,
402+ ) -> Result < f64 , Self :: LogpError > {
403+ unimplemented ! ( )
404+ }
405+
406+ fn transformed_logp (
407+ & mut self ,
408+ _params : & Self :: TransformParams ,
409+ _untransformed_position : & [ f64 ] ,
410+ _untransformed_gradient : & mut [ f64 ] ,
411+ _transformed_position : & mut [ f64 ] ,
412+ _transformed_gradient : & mut [ f64 ] ,
413+ ) -> Result < ( f64 , f64 ) , Self :: LogpError > {
414+ unimplemented ! ( )
415+ }
416+
417+ fn update_transformation < ' a , R : rand:: Rng + ?Sized > (
418+ & ' a mut self ,
419+ _rng : & mut R ,
420+ _untransformed_positions : impl Iterator < Item = & ' a [ f64 ] > ,
421+ _untransformed_gradients : impl Iterator < Item = & ' a [ f64 ] > ,
422+ _params : & ' a mut Self :: TransformParams ,
423+ ) -> Result < ( ) , Self :: LogpError > {
424+ unimplemented ! ( )
425+ }
426+
427+ fn new_transformation (
428+ & mut self ,
429+ _untransformed_position : & [ f64 ] ,
430+ _untransfogmed_gradient : & [ f64 ] ,
431+ ) -> Result < Self :: TransformParams , Self :: LogpError > {
432+ unimplemented ! ( )
433+ }
434+
435+ fn transformation_id ( & self , _params : & Self :: TransformParams ) -> i64 {
436+ unimplemented ! ( )
437+ }
363438 }
364439}
365440
@@ -368,11 +443,8 @@ mod test {
368443 use super :: test_logps:: NormalLogp ;
369444 use super :: * ;
370445 use crate :: {
371- cpu_math:: CpuMath ,
372- mass_matrix:: DiagMassMatrix ,
373- nuts:: { AdaptStrategy , Chain , NutsChain , NutsOptions } ,
374- potential:: EuclideanPotential ,
375- DiagAdaptExpSettings ,
446+ chain:: NutsChain , cpu_math:: CpuMath , euclidean_hamiltonian:: EuclideanHamiltonian ,
447+ mass_matrix:: DiagMassMatrix , Chain , DiagAdaptExpSettings ,
376448 } ;
377449
378450 #[ test]
@@ -383,14 +455,14 @@ mod test {
383455 let func = NormalLogp :: new ( ndim, 3. ) ;
384456 let mut math = CpuMath :: new ( func) ;
385457 let num_tune = 100 ;
386- let options = AdaptOptions :: < DiagAdaptExpSettings > :: default ( ) ;
458+ let options = EuclideanAdaptOptions :: < DiagAdaptExpSettings > :: default ( ) ;
387459 let strategy = GlobalStrategy :: < _ , Strategy < _ > > :: new ( & mut math, options, num_tune) ;
388460
389461 let mass_matrix = DiagMassMatrix :: new ( & mut math, true ) ;
390462 let max_energy_error = 1000f64 ;
391463 let step_size = 0.1f64 ;
392464
393- let potential = EuclideanPotential :: new ( mass_matrix, max_energy_error, step_size) ;
465+ let hamiltonian = EuclideanHamiltonian :: new ( mass_matrix, max_energy_error, step_size) ;
394466 let options = NutsOptions {
395467 maxdepth : 10u64 ,
396468 store_gradient : true ,
@@ -405,7 +477,7 @@ mod test {
405477 } ;
406478 let chain = 0u64 ;
407479
408- let mut sampler = NutsChain :: new ( math, potential , strategy, options, rng, chain) ;
480+ let mut sampler = NutsChain :: new ( math, hamiltonian , strategy, options, rng, chain) ;
409481 sampler. set_position ( & vec ! [ 1.5f64 ; ndim] ) . unwrap ( ) ;
410482 for _ in 0 ..200 {
411483 sampler. draw ( ) . unwrap ( ) ;
0 commit comments