@@ -5,25 +5,26 @@ use itertools::Itertools;
5
5
use rand:: Rng ;
6
6
7
7
use crate :: {
8
+ chain:: AdaptStrategy ,
9
+ hamiltonian:: { DivergenceInfo , Hamiltonian , Point } ,
8
10
mass_matrix_adapt:: MassMatrixAdaptStrategy ,
9
11
math_base:: Math ,
10
- nuts:: { AdaptStats , AdaptStrategy , Collector , NutsOptions } ,
12
+ nuts:: { Collector , NutsOptions } ,
11
13
sampler:: Settings ,
14
+ sampler_stats:: { SamplerStats , StatTraceBuilder } ,
12
15
state:: State ,
13
16
stepsize:: AcceptanceRateCollector ,
14
17
stepsize_adapt:: {
15
18
DualAverageSettings , Stats as StepSizeStats , StatsBuilder as StepSizeStatsBuilder ,
16
19
Strategy as StepSizeStrategy ,
17
20
} ,
18
- DivergenceInfo ,
21
+ NutsError ,
19
22
} ;
20
23
21
- use crate :: nuts:: { SamplerStats , StatTraceBuilder } ;
22
-
23
24
pub struct GlobalStrategy < M : Math , A : MassMatrixAdaptStrategy < M > > {
24
25
step_size : StepSizeStrategy ,
25
26
mass_matrix : A ,
26
- options : AdaptOptions < A :: Options > ,
27
+ options : EuclideanAdaptOptions < A :: Options > ,
27
28
num_tune : u64 ,
28
29
// The number of draws in the the early window
29
30
early_end : u64 ,
@@ -36,7 +37,7 @@ pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
36
37
}
37
38
38
39
#[ derive( Debug , Clone , Copy ) ]
39
- pub struct AdaptOptions < S : Debug + Default > {
40
+ pub struct EuclideanAdaptOptions < S : Debug + Default > {
40
41
pub dual_average_options : DualAverageSettings ,
41
42
pub mass_matrix_options : S ,
42
43
pub early_window : f64 ,
@@ -46,7 +47,7 @@ pub struct AdaptOptions<S: Debug + Default> {
46
47
pub mass_matrix_update_freq : u64 ,
47
48
}
48
49
49
- impl < S : Debug + Default > Default for AdaptOptions < S > {
50
+ impl < S : Debug + Default > Default for EuclideanAdaptOptions < S > {
50
51
fn default ( ) -> Self {
51
52
Self {
52
53
dual_average_options : DualAverageSettings :: default ( ) ,
@@ -79,16 +80,15 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> SamplerStats<M> for GlobalStrategy<
79
80
}
80
81
}
81
82
82
- impl < M : Math , A : MassMatrixAdaptStrategy < M > > AdaptStats < M > for GlobalStrategy < M , A > {
83
- fn num_grad_evals ( stats : & Self :: Stats ) -> usize {
84
- stats. stats1 . n_steps as usize
85
- }
86
- }
87
-
88
83
impl < M : Math , A : MassMatrixAdaptStrategy < M > > AdaptStrategy < M > for GlobalStrategy < M , A > {
89
- type Potential = A :: Potential ;
90
- type Collector = CombinedCollector < M , AcceptanceRateCollector , A :: Collector > ;
91
- type Options = AdaptOptions < A :: Options > ;
84
+ type Hamiltonian = A :: Hamiltonian ;
85
+ type Collector = CombinedCollector <
86
+ M ,
87
+ <Self :: Hamiltonian as Hamiltonian < M > >:: Point ,
88
+ AcceptanceRateCollector ,
89
+ A :: Collector ,
90
+ > ;
91
+ type Options = EuclideanAdaptOptions < A :: Options > ;
92
92
93
93
fn new ( math : & mut M , options : Self :: Options , num_tune : u64 ) -> Self {
94
94
let num_tune_f = num_tune as f64 ;
@@ -115,29 +115,32 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
115
115
& mut self ,
116
116
math : & mut M ,
117
117
options : & mut NutsOptions ,
118
- potential : & mut Self :: Potential ,
119
- state : & State < M > ,
118
+ hamiltonian : & mut Self :: Hamiltonian ,
119
+ state : & State < M , < Self :: Hamiltonian as Hamiltonian < M > > :: Point > ,
120
120
rng : & mut R ,
121
- ) {
122
- self . mass_matrix . init ( math, options, potential, state, rng) ;
123
- self . step_size . init ( math, options, potential, state, rng) ;
121
+ ) -> Result < ( ) , NutsError > {
122
+ self . mass_matrix
123
+ . init ( math, options, hamiltonian, state, rng) ?;
124
+ self . step_size
125
+ . init ( math, options, hamiltonian, state, rng) ?;
126
+ Ok ( ( ) )
124
127
}
125
128
126
129
fn adapt < R : Rng + ?Sized > (
127
130
& mut self ,
128
131
math : & mut M ,
129
132
options : & mut NutsOptions ,
130
- potential : & mut Self :: Potential ,
133
+ hamiltonian : & mut Self :: Hamiltonian ,
131
134
draw : u64 ,
132
135
collector : & Self :: Collector ,
133
- state : & State < M > ,
136
+ state : & State < M , < Self :: Hamiltonian as Hamiltonian < M > > :: Point > ,
134
137
rng : & mut R ,
135
- ) {
138
+ ) -> Result < ( ) , NutsError > {
136
139
self . step_size . update ( & collector. collector1 ) ;
137
140
138
141
if draw >= self . num_tune {
139
142
self . tuning = false ;
140
- return ;
143
+ return Ok ( ( ) ) ;
141
144
}
142
145
143
146
if draw < self . final_step_size_window {
@@ -165,7 +168,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
165
168
let did_change = if force_update
166
169
| ( draw - self . last_update >= self . options . mass_matrix_update_freq )
167
170
{
168
- self . mass_matrix . update_potential ( math, potential )
171
+ self . mass_matrix . update_potential ( math, hamiltonian )
169
172
} else {
170
173
false
171
174
} ;
@@ -183,24 +186,25 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
183
186
// First time we change the mass matrix
184
187
if did_change & self . has_initial_mass_matrix {
185
188
self . has_initial_mass_matrix = false ;
186
- self . step_size . init ( math, options, potential, state, rng) ;
189
+ self . step_size
190
+ . init ( math, options, hamiltonian, state, rng) ?;
187
191
} else {
188
- self . step_size . update_stepsize ( potential , false )
192
+ self . step_size . update_stepsize ( hamiltonian , false )
189
193
}
190
- return ;
194
+ return Ok ( ( ) ) ;
191
195
}
192
196
193
197
self . step_size . update_estimator_late ( ) ;
194
198
let is_last = draw == self . num_tune - 1 ;
195
- self . step_size . update_stepsize ( potential, is_last) ;
199
+ self . step_size . update_stepsize ( hamiltonian, is_last) ;
200
+ Ok ( ( ) )
196
201
}
197
202
198
203
fn new_collector ( & self , math : & mut M ) -> Self :: Collector {
199
- CombinedCollector {
200
- collector1 : self . step_size . new_collector ( ) ,
201
- collector2 : self . mass_matrix . new_collector ( math) ,
202
- _phantom : PhantomData ,
203
- }
204
+ Self :: Collector :: new (
205
+ self . step_size . new_collector ( ) ,
206
+ self . mass_matrix . new_collector ( math) ,
207
+ )
204
208
}
205
209
206
210
fn is_tuning ( & self ) -> bool {
@@ -277,22 +281,48 @@ where
277
281
}
278
282
}
279
283
280
- pub struct CombinedCollector < M : Math , C1 : Collector < M > , C2 : Collector < M > > {
281
- collector1 : C1 ,
282
- collector2 : C2 ,
284
+ pub struct CombinedCollector < M , P , C1 , C2 >
285
+ where
286
+ M : Math ,
287
+ P : Point < M > ,
288
+ C1 : Collector < M , P > ,
289
+ C2 : Collector < M , P > ,
290
+ {
291
+ pub collector1 : C1 ,
292
+ pub collector2 : C2 ,
283
293
_phantom : PhantomData < M > ,
294
+ _phantom2 : PhantomData < P > ,
284
295
}
285
296
286
- impl < M : Math , C1 , C2 > Collector < M > for CombinedCollector < M , C1 , C2 >
297
+ impl < M , P , C1 , C2 > CombinedCollector < M , P , C1 , C2 >
287
298
where
288
- C1 : Collector < M > ,
289
- C2 : Collector < M > ,
299
+ M : Math ,
300
+ P : Point < M > ,
301
+ C1 : Collector < M , P > ,
302
+ C2 : Collector < M , P > ,
303
+ {
304
+ pub fn new ( collector1 : C1 , collector2 : C2 ) -> Self {
305
+ CombinedCollector {
306
+ collector1,
307
+ collector2,
308
+ _phantom : PhantomData ,
309
+ _phantom2 : PhantomData ,
310
+ }
311
+ }
312
+ }
313
+
314
+ impl < M , P , C1 , C2 > Collector < M , P > for CombinedCollector < M , P , C1 , C2 >
315
+ where
316
+ M : Math ,
317
+ P : Point < M > ,
318
+ C1 : Collector < M , P > ,
319
+ C2 : Collector < M , P > ,
290
320
{
291
321
fn register_leapfrog (
292
322
& mut self ,
293
323
math : & mut M ,
294
- start : & State < M > ,
295
- end : & State < M > ,
324
+ start : & State < M , P > ,
325
+ end : & State < M , P > ,
296
326
divergence_info : Option < & DivergenceInfo > ,
297
327
) {
298
328
self . collector1
@@ -301,15 +331,15 @@ where
301
331
. register_leapfrog ( math, start, end, divergence_info) ;
302
332
}
303
333
304
- fn register_draw ( & mut self , math : & mut M , state : & State < M > , info : & crate :: nuts:: SampleInfo ) {
334
+ fn register_draw ( & mut self , math : & mut M , state : & State < M , P > , info : & crate :: nuts:: SampleInfo ) {
305
335
self . collector1 . register_draw ( math, state, info) ;
306
336
self . collector2 . register_draw ( math, state, info) ;
307
337
}
308
338
309
339
fn register_init (
310
340
& mut self ,
311
341
math : & mut M ,
312
- state : & State < M > ,
342
+ state : & State < M , P > ,
313
343
options : & crate :: nuts:: NutsOptions ,
314
344
) {
315
345
self . collector1 . register_init ( math, state, options) ;
@@ -319,7 +349,7 @@ where
319
349
320
350
#[ cfg( test) ]
321
351
pub mod test_logps {
322
- use crate :: { cpu_math:: CpuLogpFunc , nuts :: LogpError } ;
352
+ use crate :: { cpu_math:: CpuLogpFunc , math_base :: LogpError } ;
323
353
use thiserror:: Error ;
324
354
325
355
#[ derive( Clone , Debug ) ]
@@ -344,6 +374,7 @@ pub mod test_logps {
344
374
345
375
impl CpuLogpFunc for NormalLogp {
346
376
type LogpError = NormalLogpError ;
377
+ type TransformParams = ( ) ;
347
378
348
379
fn dim ( & self ) -> usize {
349
380
self . dim
@@ -360,6 +391,50 @@ pub mod test_logps {
360
391
}
361
392
Ok ( logp)
362
393
}
394
+
395
+ fn inv_transform_normalize (
396
+ & mut self ,
397
+ _params : & Self :: TransformParams ,
398
+ _untransformed_position : & [ f64 ] ,
399
+ _untransofrmed_gradient : & [ f64 ] ,
400
+ _transformed_position : & mut [ f64 ] ,
401
+ _transformed_gradient : & mut [ f64 ] ,
402
+ ) -> Result < f64 , Self :: LogpError > {
403
+ unimplemented ! ( )
404
+ }
405
+
406
+ fn transformed_logp (
407
+ & mut self ,
408
+ _params : & Self :: TransformParams ,
409
+ _untransformed_position : & [ f64 ] ,
410
+ _untransformed_gradient : & mut [ f64 ] ,
411
+ _transformed_position : & mut [ f64 ] ,
412
+ _transformed_gradient : & mut [ f64 ] ,
413
+ ) -> Result < ( f64 , f64 ) , Self :: LogpError > {
414
+ unimplemented ! ( )
415
+ }
416
+
417
+ fn update_transformation < ' a , R : rand:: Rng + ?Sized > (
418
+ & ' a mut self ,
419
+ _rng : & mut R ,
420
+ _untransformed_positions : impl Iterator < Item = & ' a [ f64 ] > ,
421
+ _untransformed_gradients : impl Iterator < Item = & ' a [ f64 ] > ,
422
+ _params : & ' a mut Self :: TransformParams ,
423
+ ) -> Result < ( ) , Self :: LogpError > {
424
+ unimplemented ! ( )
425
+ }
426
+
427
+ fn new_transformation (
428
+ & mut self ,
429
+ _untransformed_position : & [ f64 ] ,
430
+ _untransfogmed_gradient : & [ f64 ] ,
431
+ ) -> Result < Self :: TransformParams , Self :: LogpError > {
432
+ unimplemented ! ( )
433
+ }
434
+
435
+ fn transformation_id ( & self , _params : & Self :: TransformParams ) -> i64 {
436
+ unimplemented ! ( )
437
+ }
363
438
}
364
439
}
365
440
@@ -368,11 +443,8 @@ mod test {
368
443
use super :: test_logps:: NormalLogp ;
369
444
use super :: * ;
370
445
use crate :: {
371
- cpu_math:: CpuMath ,
372
- mass_matrix:: DiagMassMatrix ,
373
- nuts:: { AdaptStrategy , Chain , NutsChain , NutsOptions } ,
374
- potential:: EuclideanPotential ,
375
- DiagAdaptExpSettings ,
446
+ chain:: NutsChain , cpu_math:: CpuMath , euclidean_hamiltonian:: EuclideanHamiltonian ,
447
+ mass_matrix:: DiagMassMatrix , Chain , DiagAdaptExpSettings ,
376
448
} ;
377
449
378
450
#[ test]
@@ -383,14 +455,14 @@ mod test {
383
455
let func = NormalLogp :: new ( ndim, 3. ) ;
384
456
let mut math = CpuMath :: new ( func) ;
385
457
let num_tune = 100 ;
386
- let options = AdaptOptions :: < DiagAdaptExpSettings > :: default ( ) ;
458
+ let options = EuclideanAdaptOptions :: < DiagAdaptExpSettings > :: default ( ) ;
387
459
let strategy = GlobalStrategy :: < _ , Strategy < _ > > :: new ( & mut math, options, num_tune) ;
388
460
389
461
let mass_matrix = DiagMassMatrix :: new ( & mut math, true ) ;
390
462
let max_energy_error = 1000f64 ;
391
463
let step_size = 0.1f64 ;
392
464
393
- let potential = EuclideanPotential :: new ( mass_matrix, max_energy_error, step_size) ;
465
+ let hamiltonian = EuclideanHamiltonian :: new ( mass_matrix, max_energy_error, step_size) ;
394
466
let options = NutsOptions {
395
467
maxdepth : 10u64 ,
396
468
store_gradient : true ,
@@ -405,7 +477,7 @@ mod test {
405
477
} ;
406
478
let chain = 0u64 ;
407
479
408
- let mut sampler = NutsChain :: new ( math, potential , strategy, options, rng, chain) ;
480
+ let mut sampler = NutsChain :: new ( math, hamiltonian , strategy, options, rng, chain) ;
409
481
sampler. set_position ( & vec ! [ 1.5f64 ; ndim] ) . unwrap ( ) ;
410
482
for _ in 0 ..200 {
411
483
sampler. draw ( ) . unwrap ( ) ;
0 commit comments