@@ -3,13 +3,11 @@ use std::marker::PhantomData;
3
3
use rand:: Rng ;
4
4
5
5
use crate :: {
6
- chain:: AdaptStrategy ,
7
- euclidean_hamiltonian:: { EuclideanHamiltonian , EuclideanPoint } ,
8
- hamiltonian:: { Hamiltonian , Point } ,
6
+ euclidean_hamiltonian:: EuclideanPoint ,
7
+ hamiltonian:: Point ,
9
8
mass_matrix:: { DiagMassMatrix , DrawGradCollector , MassMatrix , RunningVariance } ,
10
- nuts:: NutsOptions ,
9
+ nuts:: { Collector , NutsOptions } ,
11
10
sampler_stats:: SamplerStats ,
12
- state:: State ,
13
11
Math , NutsError , Settings ,
14
12
} ;
15
13
const LOWER_LIMIT : f64 = 1e-20f64 ;
@@ -43,8 +41,10 @@ pub struct Strategy<M: Math> {
43
41
_phantom : PhantomData < M > ,
44
42
}
45
43
46
- pub trait MassMatrixAdaptStrategy < M : Math > : AdaptStrategy < M > {
44
+ pub trait MassMatrixAdaptStrategy < M : Math > : SamplerStats < M > {
47
45
type MassMatrix : MassMatrix < M > ;
46
+ type Collector : Collector < M , EuclideanPoint < M > > ;
47
+ type Options : std:: fmt:: Debug + Default + Clone + Send + Sync + Copy ;
48
48
49
49
fn update_estimators ( & mut self , math : & mut M , collector : & Self :: Collector ) ;
50
50
@@ -55,11 +55,26 @@ pub trait MassMatrixAdaptStrategy<M: Math>: AdaptStrategy<M> {
55
55
fn background_count ( & self ) -> u64 ;
56
56
57
57
/// Give the opportunity to update the potential and return if it was changed
58
- fn update_potential ( & self , math : & mut M , potential : & mut Self :: Hamiltonian ) -> bool ;
58
+ fn adapt ( & self , math : & mut M , mass_matrix : & mut Self :: MassMatrix ) -> bool ;
59
+
60
+ fn new ( math : & mut M , options : Self :: Options , _num_tune : u64 , _chain : u64 ) -> Self ;
61
+
62
+ fn init < R : Rng + ?Sized > (
63
+ & mut self ,
64
+ math : & mut M ,
65
+ _options : & mut NutsOptions ,
66
+ mass_matrix : & mut Self :: MassMatrix ,
67
+ point : & impl Point < M > ,
68
+ _rng : & mut R ,
69
+ ) -> Result < ( ) , NutsError > ;
70
+
71
+ fn new_collector ( & self , math : & mut M ) -> Self :: Collector ;
59
72
}
60
73
61
74
impl < M : Math > MassMatrixAdaptStrategy < M > for Strategy < M > {
62
75
type MassMatrix = DiagMassMatrix < M > ;
76
+ type Collector = DrawGradCollector < M > ;
77
+ type Options = DiagAdaptExpSettings ;
63
78
64
79
fn update_estimators ( & mut self , math : & mut M , collector : & DrawGradCollector < M > ) {
65
80
if collector. is_good {
@@ -88,11 +103,7 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
88
103
}
89
104
90
105
/// Give the opportunity to update the potential and return if it was changed
91
- fn update_potential (
92
- & self ,
93
- math : & mut M ,
94
- potential : & mut EuclideanHamiltonian < M , Self :: MassMatrix > ,
95
- ) -> bool {
106
+ fn adapt ( & self , math : & mut M , mass_matrix : & mut DiagMassMatrix < M > ) -> bool {
96
107
if self . current_count ( ) < 3 {
97
108
return false ;
98
109
}
@@ -102,7 +113,7 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
102
113
assert ! ( draw_scale == grad_scale) ;
103
114
104
115
if self . _settings . use_grad_based_estimate {
105
- potential . mass_matrix . update_diag_draw_grad (
116
+ mass_matrix. update_diag_draw_grad (
106
117
math,
107
118
draw_var,
108
119
grad_var,
@@ -111,35 +122,11 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
111
122
) ;
112
123
} else {
113
124
let scale = ( self . exp_variance_draw . count ( ) as f64 ) . recip ( ) ;
114
- potential. mass_matrix . update_diag_draw (
115
- math,
116
- draw_var,
117
- scale,
118
- None ,
119
- ( LOWER_LIMIT , UPPER_LIMIT ) ,
120
- ) ;
125
+ mass_matrix. update_diag_draw ( math, draw_var, scale, None , ( LOWER_LIMIT , UPPER_LIMIT ) ) ;
121
126
}
122
127
123
128
true
124
129
}
125
- }
126
-
127
- pub type Stats = ( ) ;
128
- pub type StatsBuilder = ( ) ;
129
-
130
- impl < M : Math > SamplerStats < M > for Strategy < M > {
131
- type Builder = Stats ;
132
- type Stats = StatsBuilder ;
133
-
134
- fn new_builder ( & self , _settings : & impl Settings , _dim : usize ) -> Self :: Builder { }
135
-
136
- fn current_stats ( & self , _math : & mut M ) -> Self :: Stats { }
137
- }
138
-
139
- impl < M : Math > AdaptStrategy < M > for Strategy < M > {
140
- type Hamiltonian = EuclideanHamiltonian < M , DiagMassMatrix < M > > ;
141
- type Collector = DrawGradCollector < M > ;
142
- type Options = DiagAdaptExpSettings ;
143
130
144
131
fn new ( math : & mut M , options : Self :: Options , _num_tune : u64 , _chain : u64 ) -> Self {
145
132
Self {
@@ -156,49 +143,37 @@ impl<M: Math> AdaptStrategy<M> for Strategy<M> {
156
143
& mut self ,
157
144
math : & mut M ,
158
145
_options : & mut NutsOptions ,
159
- hamiltonian : & mut Self :: Hamiltonian ,
160
- position : & [ f64 ] ,
146
+ mass_matrix : & mut Self :: MassMatrix ,
147
+ point : & impl Point < M > ,
161
148
_rng : & mut R ,
162
149
) -> Result < ( ) , NutsError > {
163
- let state = hamiltonian. init_state ( math, position) ?;
164
-
165
- self . exp_variance_draw
166
- . add_sample ( math, state. point ( ) . position ( ) ) ;
167
- self . exp_variance_draw_bg
168
- . add_sample ( math, state. point ( ) . position ( ) ) ;
169
- self . exp_variance_grad
170
- . add_sample ( math, state. point ( ) . gradient ( ) ) ;
171
- self . exp_variance_grad_bg
172
- . add_sample ( math, state. point ( ) . gradient ( ) ) ;
173
-
174
- hamiltonian. mass_matrix . update_diag_grad (
150
+ self . exp_variance_draw . add_sample ( math, point. position ( ) ) ;
151
+ self . exp_variance_draw_bg . add_sample ( math, point. position ( ) ) ;
152
+ self . exp_variance_grad . add_sample ( math, point. gradient ( ) ) ;
153
+ self . exp_variance_grad_bg . add_sample ( math, point. gradient ( ) ) ;
154
+
155
+ mass_matrix. update_diag_grad (
175
156
math,
176
- state . point ( ) . gradient ( ) ,
157
+ point. gradient ( ) ,
177
158
1f64 ,
178
159
( INIT_LOWER_LIMIT , INIT_UPPER_LIMIT ) ,
179
160
) ;
180
161
Ok ( ( ) )
181
162
}
182
163
183
- fn adapt < R : Rng + ?Sized > (
184
- & mut self ,
185
- _math : & mut M ,
186
- _options : & mut NutsOptions ,
187
- _potential : & mut Self :: Hamiltonian ,
188
- _draw : u64 ,
189
- _collector : & Self :: Collector ,
190
- _state : & State < M , EuclideanPoint < M > > ,
191
- _rng : & mut R ,
192
- ) -> Result < ( ) , NutsError > {
193
- // Must be controlled from a different meta strategy
194
- Ok ( ( ) )
195
- }
196
-
197
164
fn new_collector ( & self , math : & mut M ) -> Self :: Collector {
198
165
DrawGradCollector :: new ( math)
199
166
}
167
+ }
200
168
201
- fn is_tuning ( & self ) -> bool {
202
- unreachable ! ( )
203
- }
169
+ pub type Stats = ( ) ;
170
+ pub type StatsBuilder = ( ) ;
171
+
172
+ impl < M : Math > SamplerStats < M > for Strategy < M > {
173
+ type Builder = Stats ;
174
+ type Stats = StatsBuilder ;
175
+
176
+ fn new_builder ( & self , _settings : & impl Settings , _dim : usize ) -> Self :: Builder { }
177
+
178
+ fn current_stats ( & self , _math : & mut M ) -> Self :: Stats { }
204
179
}
0 commit comments