@@ -21,7 +21,7 @@ use crate::{
21
21
use crate :: nuts:: { SamplerStats , StatTraceBuilder } ;
22
22
23
23
pub struct GlobalStrategy < M : Math , A : MassMatrixAdaptStrategy < M > > {
24
- step_size : StepSizeStrategy < M , A > ,
24
+ step_size : StepSizeStrategy ,
25
25
mass_matrix : A ,
26
26
options : AdaptOptions < A :: Options > ,
27
27
num_tune : u64 ,
@@ -73,7 +73,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> SamplerStats<M> for GlobalStrategy<
73
73
74
74
fn new_builder ( & self , settings : & impl Settings , dim : usize ) -> Self :: Builder {
75
75
CombinedStatsBuilder {
76
- stats1 : self . step_size . new_builder ( settings, dim) ,
76
+ stats1 : SamplerStats :: < M > :: new_builder ( & self . step_size , settings, dim) ,
77
77
stats2 : self . mass_matrix . new_builder ( settings, dim) ,
78
78
}
79
79
}
@@ -87,7 +87,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStats<M> for GlobalStrategy<M,
87
87
88
88
impl < M : Math , A : MassMatrixAdaptStrategy < M > > AdaptStrategy < M > for GlobalStrategy < M , A > {
89
89
type Potential = A :: Potential ;
90
- type Collector = CombinedCollector < M , AcceptanceRateCollector < M > , A :: Collector > ;
90
+ type Collector = CombinedCollector < M , AcceptanceRateCollector , A :: Collector > ;
91
91
type Options = AdaptOptions < A :: Options > ;
92
92
93
93
fn new ( math : & mut M , options : Self :: Options , num_tune : u64 ) -> Self {
@@ -99,7 +99,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
99
99
assert ! ( early_end < num_tune) ;
100
100
101
101
Self {
102
- step_size : StepSizeStrategy :: new ( math , options. dual_average_options , num_tune ) ,
102
+ step_size : StepSizeStrategy :: new ( options. dual_average_options ) ,
103
103
mass_matrix : A :: new ( math, options. mass_matrix_options , num_tune) ,
104
104
options,
105
105
num_tune,
@@ -121,7 +121,6 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
121
121
) {
122
122
self . mass_matrix . init ( math, options, potential, state, rng) ;
123
123
self . step_size . init ( math, options, potential, state, rng) ;
124
- self . step_size . enable ( ) ;
125
124
}
126
125
127
126
fn adapt < R : Rng + ?Sized > (
@@ -134,6 +133,8 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
134
133
state : & State < M > ,
135
134
rng : & mut R ,
136
135
) {
136
+ self . step_size . update ( & collector. collector1 ) ;
137
+
137
138
if draw >= self . num_tune {
138
139
self . tuning = false ;
139
140
return ;
@@ -172,44 +173,31 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
172
173
if did_change {
173
174
self . last_update = draw;
174
175
}
176
+
175
177
if is_late {
176
- self . step_size . use_mean_sym ( ) ;
178
+ self . step_size . update_estimator_late ( ) ;
179
+ } else {
180
+ self . step_size . update_estimator_early ( ) ;
177
181
}
182
+
178
183
// First time we change the mass matrix
179
184
if did_change & self . has_initial_mass_matrix {
180
185
self . has_initial_mass_matrix = false ;
181
186
self . step_size . init ( math, options, potential, state, rng) ;
182
187
} else {
183
- self . step_size . adapt (
184
- math,
185
- options,
186
- potential,
187
- draw,
188
- & collector. collector1 ,
189
- state,
190
- rng,
191
- ) ;
188
+ self . step_size . update_stepsize ( potential, false )
192
189
}
193
190
return ;
194
191
}
195
192
196
- if draw == self . num_tune - 1 {
197
- self . step_size . finalize ( ) ;
198
- }
199
- self . step_size . adapt (
200
- math,
201
- options,
202
- potential,
203
- draw,
204
- & collector. collector1 ,
205
- state,
206
- rng,
207
- ) ;
193
+ self . step_size . update_estimator_late ( ) ;
194
+ let is_last = draw == self . num_tune - 1 ;
195
+ self . step_size . update_stepsize ( potential, is_last) ;
208
196
}
209
197
210
198
fn new_collector ( & self , math : & mut M ) -> Self :: Collector {
211
199
CombinedCollector {
212
- collector1 : self . step_size . new_collector ( math ) ,
200
+ collector1 : self . step_size . new_collector ( ) ,
213
201
collector2 : self . mass_matrix . new_collector ( math) ,
214
202
_phantom : PhantomData ,
215
203
}
0 commit comments