Skip to content

Commit 3648e00

Browse files
committed
feat: clean up walnuts a little bit
1 parent 10bd9b9 commit 3648e00

10 files changed

+248
-155
lines changed

src/adapt_strategy.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,12 @@ where
339339
start: &State<M, P>,
340340
end: &State<M, P>,
341341
divergence_info: Option<&DivergenceInfo>,
342+
num_substeps: u64,
342343
) {
343344
self.collector1
344-
.register_leapfrog(math, start, end, divergence_info);
345+
.register_leapfrog(math, start, end, divergence_info, num_substeps);
345346
self.collector2
346-
.register_leapfrog(math, start, end, divergence_info);
347+
.register_leapfrog(math, start, end, divergence_info, num_substeps);
347348
}
348349

349350
fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {

src/chain.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ where
183183
&mut self.hamiltonian,
184184
&self.options,
185185
&mut self.collector,
186+
self.draw_count < 70,
186187
)?;
187188
let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
188189
state.write_position(math, &mut position);
@@ -237,6 +238,7 @@ pub struct NutsStatsBuilder<M: Math, A: AdaptStrategy<M>> {
237238
divergence_start_grad: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
238239
divergence_end: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
239240
divergence_momentum: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
241+
non_reversible: Option<BooleanBuilder>,
240242
divergence_msg: Option<StringBuilder>,
241243
}
242244

@@ -274,7 +276,9 @@ impl<M: Math, A: AdaptStrategy<M>> NutsStatsBuilder<M, A> {
274276
None
275277
};
276278

277-
let (div_start, div_start_grad, div_end, div_mom, div_msg) = if options.store_divergences {
279+
let (div_start, div_start_grad, div_end, div_mom, non_rev, div_msg) = if options
280+
.store_divergences
281+
{
278282
let start_location_prim = PrimitiveBuilder::new();
279283
let start_location_list = FixedSizeListBuilder::new(start_location_prim, dim as i32);
280284

@@ -288,17 +292,20 @@ impl<M: Math, A: AdaptStrategy<M>> NutsStatsBuilder<M, A> {
288292
let momentum_location_list =
289293
FixedSizeListBuilder::new(momentum_location_prim, dim as i32);
290294

295+
let non_reversible = BooleanBuilder::new();
296+
291297
let msg_list = StringBuilder::new();
292298

293299
(
294300
Some(start_location_list),
295301
Some(start_grad_list),
296302
Some(end_location_list),
297303
Some(momentum_location_list),
304+
Some(non_reversible),
298305
Some(msg_list),
299306
)
300307
} else {
301-
(None, None, None, None, None)
308+
(None, None, None, None, None, None)
302309
};
303310

304311
Self {
@@ -320,6 +327,7 @@ impl<M: Math, A: AdaptStrategy<M>> NutsStatsBuilder<M, A> {
320327
divergence_start_grad: div_start_grad,
321328
divergence_end: div_end,
322329
divergence_momentum: div_mom,
330+
non_reversible: non_rev,
323331
divergence_msg: div_msg,
324332
}
325333
}
@@ -350,6 +358,7 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> StatTraceBuilder<M, NutsChain<M
350358
divergence_start_grad,
351359
divergence_end,
352360
divergence_momentum,
361+
non_reversible,
353362
divergence_msg,
354363
} = self;
355364

@@ -414,6 +423,14 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> StatTraceBuilder<M, NutsChain<M
414423
n_dim,
415424
);
416425

426+
if let Some(non_rev) = non_reversible.as_mut() {
427+
if let Some(info) = div_info {
428+
non_rev.append_value(info.non_reversible);
429+
} else {
430+
non_rev.append_null();
431+
}
432+
}
433+
417434
if let Some(div_msg) = divergence_msg.as_mut() {
418435
if let Some(err) = div_info.and_then(|info| info.logp_function_error.as_ref()) {
419436
div_msg.append_value(format!("{err}"));
@@ -447,6 +464,7 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> StatTraceBuilder<M, NutsChain<M
447464
divergence_start_grad,
448465
divergence_end,
449466
divergence_momentum,
467+
non_reversible,
450468
divergence_msg,
451469
} = self;
452470

@@ -541,6 +559,11 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> StatTraceBuilder<M, NutsChain<M
541559
&mut fields,
542560
);
543561

562+
if let Some(mut non_reversible) = non_reversible {
563+
fields.push(Field::new("non_reversible", DataType::Boolean, true));
564+
arrays.push(ArrayBuilder::finish(&mut non_reversible));
565+
}
566+
544567
let fields = Fields::from(fields);
545568
Some(StructArray::new(fields, arrays, None))
546569
}
@@ -565,6 +588,7 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> StatTraceBuilder<M, NutsChain<M
565588
divergence_start_grad,
566589
divergence_end,
567590
divergence_momentum,
591+
non_reversible,
568592
divergence_msg,
569593
} = self;
570594

@@ -659,6 +683,11 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> StatTraceBuilder<M, NutsChain<M
659683
&mut fields,
660684
);
661685

686+
if let Some(non_reversible) = non_reversible {
687+
fields.push(Field::new("non_reversible", DataType::Boolean, true));
688+
arrays.push(ArrayBuilder::finish_cloned(non_reversible));
689+
}
690+
662691
let fields = Fields::from(fields);
663692
Some(StructArray::new(fields, arrays, None))
664693
}

src/euclidean_hamiltonian.rs

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
309309
math: &mut M,
310310
start: &State<M, Self::Point>,
311311
dir: Direction,
312-
step_size_factor: f64,
312+
step_size_splits: u64,
313313
collector: &mut C,
314314
) -> LeapfrogResult<M, Self::Point> {
315315
let mut out = self.pool().new_state(math);
@@ -322,7 +322,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
322322
Direction::Backward => -1,
323323
};
324324

325-
let epsilon = (sign as f64) * self.step_size * step_size_factor;
325+
let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64);
326326

327327
start
328328
.point()
@@ -334,17 +334,9 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
334334
if !logp_error.is_recoverable() {
335335
return LeapfrogResult::Err(logp_error);
336336
}
337-
let div_info = DivergenceInfo {
338-
logp_function_error: Some(Arc::new(Box::new(logp_error))),
339-
start_location: Some(math.box_array(start.point().position())),
340-
start_gradient: Some(math.box_array(&start.point().gradient)),
341-
start_momentum: Some(math.box_array(&start.point().momentum)),
342-
end_location: None,
343-
start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
344-
end_idx_in_trajectory: None,
345-
energy_error: None,
346-
};
347-
collector.register_leapfrog(math, start, &out, Some(&div_info));
337+
let error = Arc::new(Box::new(logp_error));
338+
let div_info = DivergenceInfo::new_logp_function_error(math, start, error);
339+
collector.register_leapfrog(math, start, &out, Some(&div_info), step_size_splits);
348340
return LeapfrogResult::Divergence(div_info);
349341
}
350342

@@ -357,23 +349,21 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
357349

358350
start.point().set_psum(math, out_point, dir);
359351

352+
// TODO: energy error measured relative to initial point or previous point?
360353
let energy_error = out_point.energy_error();
361354
if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
362-
let divergence_info = DivergenceInfo {
363-
logp_function_error: None,
364-
start_location: Some(math.box_array(start.point().position())),
365-
start_gradient: Some(math.box_array(start.point().gradient())),
366-
end_location: Some(math.box_array(&out_point.position)),
367-
start_momentum: Some(math.box_array(&out_point.momentum)),
368-
start_idx_in_trajectory: Some(start.index_in_trajectory()),
369-
end_idx_in_trajectory: Some(out.index_in_trajectory()),
370-
energy_error: Some(energy_error),
371-
};
372-
collector.register_leapfrog(math, start, &out, Some(&divergence_info));
355+
let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out);
356+
collector.register_leapfrog(
357+
math,
358+
start,
359+
&out,
360+
Some(&divergence_info),
361+
step_size_splits,
362+
);
373363
return LeapfrogResult::Divergence(divergence_info);
374364
}
375365

376-
collector.register_leapfrog(math, start, &out, None);
366+
collector.register_leapfrog(math, start, &out, None, step_size_splits);
377367

378368
LeapfrogResult::Ok(out)
379369
}
@@ -447,4 +437,8 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
447437
fn step_size_mut(&mut self) -> &mut f64 {
448438
&mut self.step_size
449439
}
440+
441+
/// Returns the `max_energy_error` field of this Hamiltonian.
///
/// `leapfrog` reports a divergence when a step's energy error is larger
/// than this value or is not finite.
fn max_energy_error(&self) -> f64 {
442+
self.max_energy_error
443+
}
450444
}

src/hamiltonian.rs

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::{
1616
/// a cutoff value or nan.
1717
/// - The logp function caused a recoverable error (eg if an ODE solver
1818
/// failed)
19+
#[non_exhaustive]
1920
#[derive(Debug, Clone)]
2021
pub struct DivergenceInfo {
2122
pub start_momentum: Option<Box<[f64]>>,
@@ -26,6 +27,7 @@ pub struct DivergenceInfo {
2627
pub end_idx_in_trajectory: Option<i64>,
2728
pub start_idx_in_trajectory: Option<i64>,
2829
pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
30+
pub non_reversible: bool,
2931
}
3032

3133
impl DivergenceInfo {
@@ -39,8 +41,67 @@ impl DivergenceInfo {
3941
end_idx_in_trajectory: None,
4042
start_idx_in_trajectory: None,
4143
logp_function_error: None,
44+
non_reversible: false,
4245
}
4346
}
47+
48+
pub fn new_energy_error_too_large<M: Math>(
49+
math: &mut M,
50+
start: &State<M, impl Point<M>>,
51+
stop: &State<M, impl Point<M>>,
52+
) -> Self {
53+
DivergenceInfo {
54+
logp_function_error: None,
55+
start_location: Some(math.box_array(start.point().position())),
56+
start_gradient: Some(math.box_array(start.point().gradient())),
57+
// TODO
58+
start_momentum: None,
59+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
60+
end_location: Some(math.box_array(&stop.point().position())),
61+
end_idx_in_trajectory: Some(stop.index_in_trajectory()),
62+
// TODO
63+
energy_error: None,
64+
non_reversible: false,
65+
}
66+
}
67+
68+
pub fn new_logp_function_error<M: Math>(
69+
math: &mut M,
70+
start: &State<M, impl Point<M>>,
71+
logp_function_error: Arc<dyn std::error::Error + Send + Sync>,
72+
) -> Self {
73+
DivergenceInfo {
74+
logp_function_error: Some(logp_function_error),
75+
start_location: Some(math.box_array(start.point().position())),
76+
start_gradient: Some(math.box_array(start.point().gradient())),
77+
// TODO
78+
start_momentum: None,
79+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
80+
end_location: None,
81+
end_idx_in_trajectory: None,
82+
energy_error: None,
83+
non_reversible: false,
84+
}
85+
}
86+
87+
pub fn new_not_reversible<M: Math>(math: &mut M, start: &State<M, impl Point<M>>) -> Self {
88+
// TODO add info about what went wrong
89+
DivergenceInfo {
90+
logp_function_error: None,
91+
start_location: Some(math.box_array(start.point().position())),
92+
start_gradient: Some(math.box_array(start.point().gradient())),
93+
// TODO
94+
start_momentum: None,
95+
start_idx_in_trajectory: Some(start.index_in_trajectory()),
96+
end_location: None,
97+
end_idx_in_trajectory: None,
98+
energy_error: None,
99+
non_reversible: true,
100+
}
101+
}
102+
/// Placeholder constructor: currently returns `info` unchanged.
///
/// NOTE(review): judging by the name, this is meant to describe a divergence
/// raised after the maximum number of step size halvings. Confirm the
/// intended contents before relying on it.
pub fn new_max_step_size_halvings<M: Math>(math: &mut M, num_steps: u64, info: Self) -> Self {
    // TODO: record `num_steps` (and whatever is needed from `math`) in the
    // returned value. Until then the arguments are discarded explicitly so
    // the stub does not emit unused-variable warnings.
    let _ = (math, num_steps);
    info
}
44105
}
45106

46107
#[derive(Debug, Copy, Clone)]
@@ -106,10 +167,44 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
106167
math: &mut M,
107168
start: &State<M, Self::Point>,
108169
dir: Direction,
109-
step_size_factor: f64,
170+
step_size_splits: u64,
110171
collector: &mut C,
111172
) -> LeapfrogResult<M, Self::Point>;
112173

174+
fn split_leapfrog<C: Collector<M, Self::Point>>(
175+
&mut self,
176+
math: &mut M,
177+
start: &State<M, Self::Point>,
178+
dir: Direction,
179+
num_steps: u64,
180+
collector: &mut C,
181+
max_error: f64,
182+
) -> LeapfrogResult<M, Self::Point> {
183+
let mut state = start.clone();
184+
185+
let mut min_energy = start.energy();
186+
let mut max_energy = min_energy;
187+
188+
for _ in 0..num_steps {
189+
state = match self.leapfrog(math, &state, dir, num_steps, collector) {
190+
LeapfrogResult::Ok(state) => state,
191+
LeapfrogResult::Divergence(info) => return LeapfrogResult::Divergence(info),
192+
LeapfrogResult::Err(err) => return LeapfrogResult::Err(err),
193+
};
194+
let energy = state.energy();
195+
min_energy = min_energy.min(energy);
196+
max_energy = max_energy.max(energy);
197+
198+
// TODO: walnuts papers says to use abs, but c++ code doesn't?
199+
if max_energy - min_energy > max_error {
200+
let info = DivergenceInfo::new_energy_error_too_large(math, start, &state);
201+
return LeapfrogResult::Divergence(info);
202+
}
203+
}
204+
205+
LeapfrogResult::Ok(state)
206+
}
207+
113208
fn is_turning(
114209
&self,
115210
math: &mut M,
@@ -141,4 +236,6 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
141236

142237
fn step_size(&self) -> f64;
143238
fn step_size_mut(&mut self) -> &mut f64;
239+
240+
fn max_energy_error(&self) -> f64;
144241
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ pub use chain::Chain;
108108
pub use cpu_math::{CpuLogpFunc, CpuMath};
109109
pub use hamiltonian::DivergenceInfo;
110110
pub use math_base::{LogpError, Math};
111-
pub use nuts::NutsError;
111+
pub use nuts::{NutsError, WalnutsOptions};
112112
pub use sampler::{
113113
sample_sequentially, ChainOutput, ChainProgress, DiagGradNutsSettings, DrawStorage,
114114
LowRankNutsSettings, Model, NutsSettings, Progress, ProgressCallback, Sampler,

0 commit comments

Comments
 (0)