Skip to content

Commit 16d2235

Browse files
committed
feat: add step size jitter
1 parent 3648e00 commit 16d2235

File tree

3 files changed

+26
-12
lines changed

3 files changed

+26
-12
lines changed

src/adapt_strategy.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
143143
self.step_size.update(&collector.collector1);
144144

145145
if draw >= self.num_tune {
146+
// Needed for step size jitter
147+
self.step_size.update_stepsize(rng, hamiltonian, true);
146148
self.tuning = false;
147149
return Ok(());
148150
}
@@ -194,14 +196,14 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
194196
self.step_size
195197
.init(math, options, hamiltonian, &position, rng)?;
196198
} else {
197-
self.step_size.update_stepsize(hamiltonian, false)
199+
self.step_size.update_stepsize(rng, hamiltonian, false)
198200
}
199201
return Ok(());
200202
}
201203

202204
self.step_size.update_estimator_late();
203205
let is_last = draw == self.num_tune - 1;
204-
self.step_size.update_stepsize(hamiltonian, is_last);
206+
self.step_size.update_stepsize(rng, hamiltonian, is_last);
205207
Ok(())
206208
}
207209

src/stepsize_adapt.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,19 +121,27 @@ impl Strategy {
121121
.advance(self.last_sym_mean_tree_accept, self.options.target_accept);
122122
}
123123

124-
pub fn update_stepsize<M: Math>(
124+
pub fn update_stepsize<M: Math, R: Rng + ?Sized>(
125125
&mut self,
126+
rng: &mut R,
126127
hamiltonian: &mut impl Hamiltonian<M>,
127128
use_best_guess: bool,
128129
) {
129-
if let Some(step_size) = self.options.fixed_step_size {
130-
*hamiltonian.step_size_mut() = step_size;
131-
return;
132-
}
133-
if use_best_guess {
134-
*hamiltonian.step_size_mut() = self.step_size_adapt.current_step_size_adapted();
130+
let step_size = if let Some(step_size) = self.options.fixed_step_size {
131+
step_size
132+
} else if use_best_guess {
133+
self.step_size_adapt.current_step_size_adapted()
135134
} else {
136-
*hamiltonian.step_size_mut() = self.step_size_adapt.current_step_size();
135+
self.step_size_adapt.current_step_size()
136+
};
137+
138+
if let Some(jitter) = self.options.jitter {
139+
let jitter =
140+
rng.sample(Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter"));
141+
let jittered_step_size = step_size * jitter;
142+
*hamiltonian.step_size_mut() = jittered_step_size;
143+
} else {
144+
*hamiltonian.step_size_mut() = step_size;
137145
}
138146
}
139147

@@ -236,6 +244,7 @@ pub struct DualAverageSettings {
236244
pub initial_step: f64,
237245
pub params: DualAverageOptions,
238246
pub fixed_step_size: Option<f64>,
247+
pub jitter: Option<f64>,
239248
}
240249

241250
impl Default for DualAverageSettings {
@@ -245,6 +254,7 @@ impl Default for DualAverageSettings {
245254
initial_step: 0.1,
246255
params: DualAverageOptions::default(),
247256
fixed_step_size: None,
257+
jitter: None,
248258
}
249259
}
250260
}

src/transform_adapt_strategy.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
213213
self.step_size.update(&collector.collector1);
214214

215215
if draw >= self.num_tune {
216+
// Needed for step size jitter
217+
self.step_size.update_stepsize(rng, hamiltonian, true);
216218
self.tuning = false;
217219
return Ok(());
218220
}
@@ -238,13 +240,13 @@ impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
238240
)?;
239241
}
240242
self.step_size.update_estimator_early();
241-
self.step_size.update_stepsize(hamiltonian, false);
243+
self.step_size.update_stepsize(rng, hamiltonian, false);
242244
return Ok(());
243245
}
244246

245247
self.step_size.update_estimator_late();
246248
let is_last = draw == self.num_tune - 1;
247-
self.step_size.update_stepsize(hamiltonian, is_last);
249+
self.step_size.update_stepsize(rng, hamiltonian, is_last);
248250
Ok(())
249251
}
250252

0 commit comments

Comments
 (0)