Skip to content

Commit 9918d38

Browse files
committed
fix: Fix bug where step size stats were not updated after tuning
1 parent 11bf46b commit 9918d38

File tree

3 files changed

+141
-194
lines changed

3 files changed

+141
-194
lines changed

src/adapt_strategy.rs

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::{
2121
use crate::nuts::{SamplerStats, StatTraceBuilder};
2222

2323
pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
24-
step_size: StepSizeStrategy<M, A>,
24+
step_size: StepSizeStrategy,
2525
mass_matrix: A,
2626
options: AdaptOptions<A::Options>,
2727
num_tune: u64,
@@ -73,7 +73,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> SamplerStats<M> for GlobalStrategy<
7373

7474
fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder {
7575
CombinedStatsBuilder {
76-
stats1: self.step_size.new_builder(settings, dim),
76+
stats1: SamplerStats::<M>::new_builder(&self.step_size, settings, dim),
7777
stats2: self.mass_matrix.new_builder(settings, dim),
7878
}
7979
}
@@ -87,7 +87,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStats<M> for GlobalStrategy<M,
8787

8888
impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
8989
type Potential = A::Potential;
90-
type Collector = CombinedCollector<M, AcceptanceRateCollector<M>, A::Collector>;
90+
type Collector = CombinedCollector<M, AcceptanceRateCollector, A::Collector>;
9191
type Options = AdaptOptions<A::Options>;
9292

9393
fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self {
@@ -99,7 +99,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
9999
assert!(early_end < num_tune);
100100

101101
Self {
102-
step_size: StepSizeStrategy::new(math, options.dual_average_options, num_tune),
102+
step_size: StepSizeStrategy::new(options.dual_average_options),
103103
mass_matrix: A::new(math, options.mass_matrix_options, num_tune),
104104
options,
105105
num_tune,
@@ -121,7 +121,6 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
121121
) {
122122
self.mass_matrix.init(math, options, potential, state, rng);
123123
self.step_size.init(math, options, potential, state, rng);
124-
self.step_size.enable();
125124
}
126125

127126
fn adapt<R: Rng + ?Sized>(
@@ -134,6 +133,8 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
134133
state: &State<M>,
135134
rng: &mut R,
136135
) {
136+
self.step_size.update(&collector.collector1);
137+
137138
if draw >= self.num_tune {
138139
self.tuning = false;
139140
return;
@@ -172,44 +173,31 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
172173
if did_change {
173174
self.last_update = draw;
174175
}
176+
175177
if is_late {
176-
self.step_size.use_mean_sym();
178+
self.step_size.update_estimator_late();
179+
} else {
180+
self.step_size.update_estimator_early();
177181
}
182+
178183
// First time we change the mass matrix
179184
if did_change & self.has_initial_mass_matrix {
180185
self.has_initial_mass_matrix = false;
181186
self.step_size.init(math, options, potential, state, rng);
182187
} else {
183-
self.step_size.adapt(
184-
math,
185-
options,
186-
potential,
187-
draw,
188-
&collector.collector1,
189-
state,
190-
rng,
191-
);
188+
self.step_size.update_stepsize(potential, false)
192189
}
193190
return;
194191
}
195192

196-
if draw == self.num_tune - 1 {
197-
self.step_size.finalize();
198-
}
199-
self.step_size.adapt(
200-
math,
201-
options,
202-
potential,
203-
draw,
204-
&collector.collector1,
205-
state,
206-
rng,
207-
);
193+
self.step_size.update_estimator_late();
194+
let is_last = draw == self.num_tune - 1;
195+
self.step_size.update_stepsize(potential, is_last);
208196
}
209197

210198
fn new_collector(&self, math: &mut M) -> Self::Collector {
211199
CombinedCollector {
212-
collector1: self.step_size.new_collector(math),
200+
collector1: self.step_size.new_collector(),
213201
collector2: self.mass_matrix.new_collector(math),
214202
_phantom: PhantomData,
215203
}

src/stepsize.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use std::marker::PhantomData;
2-
31
use crate::{
42
math_base::Math,
53
nuts::{Collector, NutsOptions},
@@ -103,25 +101,23 @@ impl RunningMean {
103101
}
104102
}
105103

106-
pub struct AcceptanceRateCollector<M: Math> {
104+
pub struct AcceptanceRateCollector {
107105
initial_energy: f64,
108106
pub(crate) mean: RunningMean,
109107
pub(crate) mean_sym: RunningMean,
110-
phantom: PhantomData<M>,
111108
}
112109

113-
impl<M: Math> AcceptanceRateCollector<M> {
114-
pub(crate) fn new() -> AcceptanceRateCollector<M> {
110+
impl AcceptanceRateCollector {
111+
pub(crate) fn new() -> AcceptanceRateCollector {
115112
AcceptanceRateCollector {
116113
initial_energy: 0.,
117114
mean: RunningMean::new(),
118115
mean_sym: RunningMean::new(),
119-
phantom: PhantomData,
120116
}
121117
}
122118
}
123119

124-
impl<M: Math> Collector<M> for AcceptanceRateCollector<M> {
120+
impl<M: Math> Collector<M> for AcceptanceRateCollector {
125121
fn register_leapfrog(
126122
&mut self,
127123
_math: &mut M,

0 commit comments

Comments
 (0)