Skip to content

Commit 064799e

Browse files
committed
feat: add untested walnuts implementation
1 parent 3aac0a5 commit 064799e

File tree

7 files changed

+185
-17
lines changed

7 files changed

+185
-17
lines changed

src/adapt_strategy.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ mod test {
482482
store_unconstrained: true,
483483
check_turning: true,
484484
store_divergences: false,
485+
walnuts_options: None,
485486
};
486487

487488
let rng = {

src/euclidean_hamiltonian.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
225225
math: &mut M,
226226
start: &State<M, Self::Point>,
227227
dir: Direction,
228+
step_size_factor: f64,
228229
collector: &mut C,
229230
) -> LeapfrogResult<M, Self::Point> {
230231
let mut out = self.pool().new_state(math);
@@ -237,7 +238,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
237238
Direction::Backward => -1,
238239
};
239240

240-
let epsilon = (sign as f64) * self.step_size;
241+
let epsilon = (sign as f64) * self.step_size * step_size_factor;
241242

242243
start
243244
.point()

src/hamiltonian.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,36 @@ pub struct DivergenceInfo {
2828
pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
2929
}
3030

31+
impl DivergenceInfo {
32+
pub fn new() -> Self {
33+
DivergenceInfo {
34+
start_momentum: None,
35+
start_location: None,
36+
start_gradient: None,
37+
end_location: None,
38+
energy_error: None,
39+
end_idx_in_trajectory: None,
40+
start_idx_in_trajectory: None,
41+
logp_function_error: None,
42+
}
43+
}
44+
}
45+
3146
#[derive(Debug, Copy, Clone)]
3247
pub enum Direction {
3348
Forward,
3449
Backward,
3550
}
3651

52+
impl Direction {
53+
pub fn reverse(&self) -> Self {
54+
match self {
55+
Direction::Forward => Direction::Backward,
56+
Direction::Backward => Direction::Forward,
57+
}
58+
}
59+
}
60+
3761
impl Distribution<Direction> for StandardUniform {
3862
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Direction {
3963
if rng.random::<bool>() {
@@ -82,6 +106,7 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
82106
math: &mut M,
83107
start: &State<M, Self::Point>,
84108
dir: Direction,
109+
step_size_factor: f64,
85110
collector: &mut C,
86111
) -> LeapfrogResult<M, Self::Point>;
87112

src/nuts.rs

Lines changed: 138 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use serde::Serialize;
12
use thiserror::Error;
23

34
use std::{fmt::Debug, marker::PhantomData};
@@ -120,7 +121,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
120121
H: Hamiltonian<M>,
121122
R: rand::Rng + ?Sized,
122123
{
123-
let mut other = match self.single_step(math, hamiltonian, direction, collector) {
124+
let mut other = match self.single_step(math, hamiltonian, direction, options, collector) {
124125
Ok(Ok(tree)) => tree,
125126
Ok(Err(info)) => return ExtendResult::Diverging(self, info),
126127
Err(err) => return ExtendResult::Err(err),
@@ -213,19 +214,141 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
213214
math: &mut M,
214215
hamiltonian: &mut H,
215216
direction: Direction,
217+
options: &NutsOptions,
216218
collector: &mut C,
217219
) -> Result<std::result::Result<NutsTree<M, H, C>, DivergenceInfo>> {
218220
let start = match direction {
219221
Direction::Forward => &self.right,
220222
Direction::Backward => &self.left,
221223
};
222-
let end = match hamiltonian.leapfrog(math, start, direction, collector) {
223-
LeapfrogResult::Divergence(info) => return Ok(Err(info)),
224-
LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
225-
LeapfrogResult::Ok(end) => end,
224+
225+
let (log_size, end) = match options.walnuts_options {
226+
Some(ref options) => {
227+
// Walnuts implementation
228+
// TODO: Shouldn't all be in this one big function...
229+
let mut step_size_factor = 1.0;
230+
let mut num_steps = 1;
231+
let mut current = start.clone();
232+
233+
let mut success = false;
234+
235+
'step_size_search: for _ in 0..options.max_step_size_halvings {
236+
current = start.clone();
237+
let mut min_energy = current.energy();
238+
let mut max_energy = min_energy;
239+
240+
for _ in 0..num_steps {
241+
current = match hamiltonian.leapfrog(
242+
math,
243+
&current,
244+
direction,
245+
step_size_factor,
246+
collector,
247+
) {
248+
LeapfrogResult::Ok(state) => state,
249+
LeapfrogResult::Divergence(_) => {
250+
num_steps *= 2;
251+
step_size_factor *= 0.5;
252+
continue 'step_size_search;
253+
}
254+
LeapfrogResult::Err(err) => {
255+
return Err(NutsError::LogpFailure(err.into()));
256+
}
257+
};
258+
259+
// Update min/max energies
260+
let current_energy = current.energy();
261+
min_energy = min_energy.min(current_energy);
262+
max_energy = max_energy.max(current_energy);
263+
}
264+
265+
if max_energy - min_energy > options.max_energy_error {
266+
num_steps *= 2;
267+
step_size_factor *= 0.5;
268+
continue 'step_size_search;
269+
}
270+
271+
success = true;
272+
break 'step_size_search;
273+
}
274+
275+
if !success {
276+
// TODO: More info
277+
return Ok(Err(DivergenceInfo::new()));
278+
}
279+
280+
// TODO
281+
let back = direction.reverse();
282+
let mut current_backward;
283+
284+
let mut reversible = true;
285+
286+
'rev_step_size: while num_steps >= 2 {
287+
num_steps /= 2;
288+
step_size_factor *= 0.5;
289+
290+
// TODO: Can we share code for the micro steps in the two directions?
291+
current_backward = current.clone();
292+
293+
let mut min_energy = current_backward.energy();
294+
let mut max_energy = min_energy;
295+
296+
for _ in 0..num_steps {
297+
current_backward = match hamiltonian.leapfrog(
298+
math,
299+
&current_backward,
300+
back,
301+
step_size_factor,
302+
collector,
303+
) {
304+
LeapfrogResult::Ok(state) => state,
305+
LeapfrogResult::Divergence(_) => {
306+
// We also reject in the backward direction, all is good so far...
307+
continue 'rev_step_size;
308+
}
309+
LeapfrogResult::Err(err) => {
310+
return Err(NutsError::LogpFailure(err.into()));
311+
}
312+
};
313+
314+
// Update min/max energies
315+
let current_energy = current_backward.energy();
316+
min_energy = min_energy.min(current_energy);
317+
max_energy = max_energy.max(current_energy);
318+
if max_energy - min_energy > options.max_energy_error {
319+
// We reject also in the backward direction, all good so far...
320+
continue 'rev_step_size;
321+
}
322+
}
323+
324+
// We did not reject in the backward direction, so we are not reversible
325+
reversible = false;
326+
break;
327+
}
328+
329+
if reversible {
330+
let log_size = -current.point().energy_error();
331+
(log_size, current)
332+
} else {
333+
// TODO: More info
334+
return Ok(Err(DivergenceInfo::new()));
335+
}
336+
}
337+
None => {
338+
// Classical NUTS
339+
//
340+
let end = match hamiltonian.leapfrog(math, start, direction, 1.0, collector) {
341+
LeapfrogResult::Divergence(info) => return Ok(Err(info)),
342+
LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
343+
LeapfrogResult::Ok(end) => end,
344+
};
345+
346+
let log_size = -end.point().energy_error();
347+
348+
(log_size, end)
349+
}
226350
};
227351

228-
let log_size = -end.point().energy_error();
229352
Ok(Ok(NutsTree {
230353
right: end.clone(),
231354
left: end.clone(),
@@ -248,13 +371,22 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
248371
}
249372
}
250373

374+
#[derive(Debug, Clone, Copy, Serialize)]
375+
pub struct WalnutsOptions {
376+
pub max_energy_error: f64,
377+
pub max_step_size_halvings: u64,
378+
}
379+
380+
#[derive(Debug, Clone, Copy)]
251381
pub struct NutsOptions {
252382
pub maxdepth: u64,
253383
pub mindepth: u64,
254384
pub store_gradient: bool,
255385
pub store_unconstrained: bool,
256386
pub check_turning: bool,
257387
pub store_divergences: bool,
388+
389+
pub walnuts_options: Option<WalnutsOptions>,
258390
}
259391

260392
pub(crate) fn draw<M, H, R, C>(

src/sampler.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ use std::{
2020
};
2121

2222
use crate::{
23-
DiagAdaptExpSettings,
23+
DiagAdaptExpSettings, Model, SamplerStats,
2424
adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
2525
chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
2626
euclidean_hamiltonian::EuclideanHamiltonian,
27-
mass_matrix::DiagMassMatrix,
28-
mass_matrix::Strategy as DiagMassMatrixStrategy,
29-
mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings},
27+
mass_matrix::{
28+
DiagMassMatrix, LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings,
29+
Strategy as DiagMassMatrixStrategy,
30+
},
3031
math_base::Math,
31-
model::Model,
32-
nuts::NutsOptions,
33-
sampler_stats::{SamplerStats, StatsDims},
32+
nuts::{NutsOptions, WalnutsOptions},
33+
sampler_stats::StatsDims,
3434
storage::{ChainStorage, StorageConfig, TraceStorage},
3535
transform_adapt_strategy::{TransformAdaptation, TransformedSettings},
3636
transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions},
@@ -185,6 +185,7 @@ pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
185185

186186
pub num_chains: usize,
187187
pub seed: u64,
188+
pub walnuts_options: Option<WalnutsOptions>,
188189
}
189190

190191
pub type DiagGradNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
@@ -206,6 +207,7 @@ impl Default for DiagGradNutsSettings {
206207
check_turning: true,
207208
seed: 0,
208209
num_chains: 6,
210+
walnuts_options: None,
209211
}
210212
}
211213
}
@@ -225,6 +227,7 @@ impl Default for LowRankNutsSettings {
225227
check_turning: true,
226228
seed: 0,
227229
num_chains: 6,
230+
walnuts_options: None,
228231
};
229232
vals.adapt_options.mass_matrix_update_freq = 10;
230233
vals
@@ -246,6 +249,7 @@ impl Default for TransformedNutsSettings {
246249
check_turning: true,
247250
seed: 0,
248251
num_chains: 1,
252+
walnuts_options: None,
249253
}
250254
}
251255
}
@@ -278,6 +282,7 @@ impl Settings for LowRankNutsSettings {
278282
store_divergences: self.store_divergences,
279283
store_unconstrained: self.store_unconstrained,
280284
check_turning: self.check_turning,
285+
walnuts_options: self.walnuts_options,
281286
};
282287

283288
let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
@@ -346,6 +351,7 @@ impl Settings for DiagGradNutsSettings {
346351
store_divergences: self.store_divergences,
347352
store_unconstrained: self.store_unconstrained,
348353
check_turning: self.check_turning,
354+
walnuts_options: self.walnuts_options,
349355
};
350356

351357
let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
@@ -411,6 +417,7 @@ impl Settings for TransformedNutsSettings {
411417
store_divergences: self.store_divergences,
412418
store_unconstrained: self.store_unconstrained,
413419
check_turning: self.check_turning,
420+
walnuts_options: self.walnuts_options,
414421
};
415422

416423
let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");

src/stepsize/adapt.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ impl Strategy {
103103

104104
*hamiltonian.step_size_mut() = self.options.initial_step;
105105

106-
let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector);
106+
let state_next =
107+
hamiltonian.leapfrog(math, &state, Direction::Forward, 1.0, &mut collector);
107108

108109
let LeapfrogResult::Ok(_) = state_next else {
109110
return Ok(());
@@ -119,7 +120,7 @@ impl Strategy {
119120
for _ in 0..100 {
120121
let mut collector = AcceptanceRateCollector::new();
121122
collector.register_init(math, &state, options);
122-
let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector);
123+
let state_next = hamiltonian.leapfrog(math, &state, dir, 1.0, &mut collector);
123124
let LeapfrogResult::Ok(_) = state_next else {
124125
*hamiltonian.step_size_mut() = self.options.initial_step;
125126
return Ok(());

src/transformed_hamiltonian.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
303303
math: &mut M,
304304
start: &State<M, Self::Point>,
305305
dir: Direction,
306+
step_size_factor: f64,
306307
collector: &mut C,
307308
) -> LeapfrogResult<M, Self::Point> {
308309
let mut out = self.pool().new_state(math);
@@ -316,7 +317,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
316317
Direction::Backward => -1,
317318
};
318319

319-
let epsilon = (sign as f64) * self.step_size;
320+
let epsilon = (sign as f64) * self.step_size * step_size_factor;
320321

321322
start
322323
.point()

0 commit comments

Comments
 (0)