Skip to content

Commit 9a4ccb8

Browse files
committed
feat: add step size jitter
1 parent 4310bd0 commit 9a4ccb8

File tree

3 files changed

+26
-12
lines changed

3 files changed

+26
-12
lines changed

src/adapt_strategy.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
143143
self.step_size.update(&collector.collector1);
144144

145145
if draw >= self.num_tune {
146+
// Needed for step size jitter
147+
self.step_size.update_stepsize(rng, hamiltonian, true);
146148
self.tuning = false;
147149
return Ok(());
148150
}
@@ -194,14 +196,14 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
194196
self.step_size
195197
.init(math, options, hamiltonian, &position, rng)?;
196198
} else {
197-
self.step_size.update_stepsize(hamiltonian, false)
199+
self.step_size.update_stepsize(rng, hamiltonian, false)
198200
}
199201
return Ok(());
200202
}
201203

202204
self.step_size.update_estimator_late();
203205
let is_last = draw == self.num_tune - 1;
204-
self.step_size.update_stepsize(hamiltonian, is_last);
206+
self.step_size.update_stepsize(rng, hamiltonian, is_last);
205207
Ok(())
206208
}
207209

src/stepsize_adapt.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,19 +120,27 @@ impl Strategy {
120120
.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
121121
}
122122

123-
pub fn update_stepsize<M: Math>(
123+
pub fn update_stepsize<M: Math, R: Rng + ?Sized>(
124124
&mut self,
125+
rng: &mut R,
125126
hamiltonian: &mut impl Hamiltonian<M>,
126127
use_best_guess: bool,
127128
) {
128-
if let Some(step_size) = self.options.fixed_step_size {
129-
*hamiltonian.step_size_mut() = step_size;
130-
return;
131-
}
132-
if use_best_guess {
133-
*hamiltonian.step_size_mut() = self.step_size_adapt.current_step_size_adapted();
129+
let step_size = if let Some(step_size) = self.options.fixed_step_size {
130+
step_size
131+
} else if use_best_guess {
132+
self.step_size_adapt.current_step_size_adapted()
134133
} else {
135-
*hamiltonian.step_size_mut() = self.step_size_adapt.current_step_size();
134+
self.step_size_adapt.current_step_size()
135+
};
136+
137+
if let Some(jitter) = self.options.jitter {
138+
let jitter =
139+
rng.sample(Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter"));
140+
let jittered_step_size = step_size * jitter;
141+
*hamiltonian.step_size_mut() = jittered_step_size;
142+
} else {
143+
*hamiltonian.step_size_mut() = step_size;
136144
}
137145
}
138146

@@ -235,6 +243,7 @@ pub struct DualAverageSettings {
235243
pub initial_step: f64,
236244
pub params: DualAverageOptions,
237245
pub fixed_step_size: Option<f64>,
246+
pub jitter: Option<f64>,
238247
}
239248

240249
impl Default for DualAverageSettings {
@@ -244,6 +253,7 @@ impl Default for DualAverageSettings {
244253
initial_step: 0.1,
245254
params: DualAverageOptions::default(),
246255
fixed_step_size: None,
256+
jitter: None,
247257
}
248258
}
249259
}

src/transform_adapt_strategy.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
212212
self.step_size.update(&collector.collector1);
213213

214214
if draw >= self.num_tune {
215+
// Needed for step size jitter
216+
self.step_size.update_stepsize(rng, hamiltonian, true);
215217
self.tuning = false;
216218
return Ok(());
217219
}
@@ -237,13 +239,13 @@ impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
237239
)?;
238240
}
239241
self.step_size.update_estimator_early();
240-
self.step_size.update_stepsize(hamiltonian, false);
242+
self.step_size.update_stepsize(rng, hamiltonian, false);
241243
return Ok(());
242244
}
243245

244246
self.step_size.update_estimator_late();
245247
let is_last = draw == self.num_tune - 1;
246-
self.step_size.update_stepsize(hamiltonian, is_last);
248+
self.step_size.update_stepsize(rng, hamiltonian, is_last);
247249
Ok(())
248250
}
249251

0 commit comments

Comments
 (0)