diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index 79939e4..b1eca0f 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -291,11 +291,12 @@ where start: &State, end: &State, divergence_info: Option<&DivergenceInfo>, + num_substeps: u64, ) { self.collector1 - .register_leapfrog(math, start, end, divergence_info); + .register_leapfrog(math, start, end, divergence_info, num_substeps); self.collector2 - .register_leapfrog(math, start, end, divergence_info); + .register_leapfrog(math, start, end, divergence_info, num_substeps); } fn register_draw(&mut self, math: &mut M, state: &State, info: &crate::nuts::SampleInfo) { @@ -482,6 +483,7 @@ mod test { store_unconstrained: true, check_turning: true, store_divergences: false, + walnuts_options: None, }; let rng = { diff --git a/src/chain.rs b/src/chain.rs index 026312a..f8f5981 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -157,6 +157,7 @@ where &mut self.hamiltonian, &self.options, &mut self.collector, + self.draw_count < 70, )?; let mut position: Box<[f64]> = vec![0f64; math.dim()].into(); state.write_position(math, &mut position); @@ -235,6 +236,7 @@ pub struct NutsStats, A: Storable

, D: Storable

> pub divergence_end: Option>, #[storable(dims("unconstrained_parameter"))] pub divergence_momentum: Option>, + non_reversible: Option, //pub divergence_message: Option, #[storable(ignore)] _phantom: PhantomData P>, @@ -303,7 +305,7 @@ impl> SamplerStats for NutsChain> Hamiltonian for EuclideanHamiltonian, dir: Direction, + step_size_splits: u64, collector: &mut C, ) -> LeapfrogResult { let mut out = self.pool().new_state(math); @@ -237,7 +238,7 @@ impl> Hamiltonian for EuclideanHamiltonian -1, }; - let epsilon = (sign as f64) * self.step_size; + let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64); start .point() @@ -249,17 +250,9 @@ impl> Hamiltonian for EuclideanHamiltonian> Hamiltonian for EuclideanHamiltonian self.max_energy_error) | !energy_error.is_finite() { - let divergence_info = DivergenceInfo { - logp_function_error: None, - start_location: Some(math.box_array(start.point().position())), - start_gradient: Some(math.box_array(start.point().gradient())), - end_location: Some(math.box_array(&out_point.position)), - start_momentum: Some(math.box_array(&out_point.momentum)), - start_idx_in_trajectory: Some(start.index_in_trajectory()), - end_idx_in_trajectory: Some(out.index_in_trajectory()), - energy_error: Some(energy_error), - }; - collector.register_leapfrog(math, start, &out, Some(&divergence_info)); + let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out); + collector.register_leapfrog( + math, + start, + &out, + Some(&divergence_info), + step_size_splits, + ); return LeapfrogResult::Divergence(divergence_info); } - collector.register_leapfrog(math, start, &out, None); + collector.register_leapfrog(math, start, &out, None, step_size_splits); LeapfrogResult::Ok(out) } @@ -362,4 +353,8 @@ impl> Hamiltonian for EuclideanHamiltonian &mut f64 { &mut self.step_size } + + fn max_energy_error(&self) -> f64 { + self.max_energy_error + } } diff --git a/src/hamiltonian.rs b/src/hamiltonian.rs index e4abcf8..b8632ae 100644 --- a/src/hamiltonian.rs +++ b/src/hamiltonian.rs @@ -16,6 +16,7 @@ use crate::{ /// a cutoff value or nan. /// - The logp function caused a recoverable error (eg if an ODE solver /// failed) +#[non_exhaustive] #[derive(Debug, Clone)] pub struct DivergenceInfo { pub start_momentum: Option>, @@ -26,6 +27,81 @@ pub struct DivergenceInfo { pub end_idx_in_trajectory: Option, pub start_idx_in_trajectory: Option, pub logp_function_error: Option>, + pub non_reversible: bool, +} + +impl DivergenceInfo { + pub fn new() -> Self { + DivergenceInfo { + start_momentum: None, + start_location: None, + start_gradient: None, + end_location: None, + energy_error: None, + end_idx_in_trajectory: None, + start_idx_in_trajectory: None, + logp_function_error: None, + non_reversible: false, + } + } + + pub fn new_energy_error_too_large( + math: &mut M, + start: &State>, + stop: &State>, + ) -> Self { + DivergenceInfo { + logp_function_error: None, + start_location: Some(math.box_array(start.point().position())), + start_gradient: Some(math.box_array(start.point().gradient())), + // TODO + start_momentum: None, + start_idx_in_trajectory: Some(start.index_in_trajectory()), + end_location: Some(math.box_array(&stop.point().position())), + end_idx_in_trajectory: Some(stop.index_in_trajectory()), + // TODO + energy_error: None, + non_reversible: false, + } + } + + pub fn new_logp_function_error( + math: &mut M, + start: &State>, + logp_function_error: Arc, + ) -> Self { + DivergenceInfo { + logp_function_error: Some(logp_function_error), + start_location: Some(math.box_array(start.point().position())), + start_gradient: Some(math.box_array(start.point().gradient())), + // TODO + start_momentum: None, + start_idx_in_trajectory: Some(start.index_in_trajectory()), + end_location: None, + end_idx_in_trajectory: None, + energy_error: None, + non_reversible: false, + } + } + + pub fn new_not_reversible(math: &mut M, start: &State>) -> Self { + // TODO add info about what went wrong + DivergenceInfo { + logp_function_error: None, + start_location: Some(math.box_array(start.point().position())), + start_gradient: Some(math.box_array(start.point().gradient())), + // TODO + start_momentum: None, + start_idx_in_trajectory: Some(start.index_in_trajectory()), + end_location: None, + end_idx_in_trajectory: None, + energy_error: None, + non_reversible: true, + } + } + pub fn new_max_step_size_halvings(math: &mut M, num_steps: u64, info: Self) -> Self { + info // TODO + } } #[derive(Debug, Copy, Clone)] @@ -34,6 +110,15 @@ pub enum Direction { Backward, } +impl Direction { + pub fn reverse(&self) -> Self { + match self { + Direction::Forward => Direction::Backward, + Direction::Backward => Direction::Forward, + } + } +} + impl Distribution for StandardUniform { fn sample(&self, rng: &mut R) -> Direction { if rng.random::() { @@ -82,9 +167,44 @@ pub trait Hamiltonian: SamplerStats + Sized { math: &mut M, start: &State, dir: Direction, + step_size_splits: u64, collector: &mut C, ) -> LeapfrogResult; + fn split_leapfrog>( + &mut self, + math: &mut M, + start: &State, + dir: Direction, + num_steps: u64, + collector: &mut C, + max_error: f64, + ) -> LeapfrogResult { + let mut state = start.clone(); + + let mut min_energy = start.energy(); + let mut max_energy = min_energy; + + for _ in 0..num_steps { + state = match self.leapfrog(math, &state, dir, num_steps, collector) { + LeapfrogResult::Ok(state) => state, + LeapfrogResult::Divergence(info) => return LeapfrogResult::Divergence(info), + LeapfrogResult::Err(err) => return LeapfrogResult::Err(err), + }; + let energy = state.energy(); + min_energy = min_energy.min(energy); + max_energy = max_energy.max(energy); + + // TODO: walnuts papers says to use abs, but c++ code doesn't? + if max_energy - min_energy > max_error { + let info = DivergenceInfo::new_energy_error_too_large(math, start, &state); + return LeapfrogResult::Divergence(info); + } + } + + LeapfrogResult::Ok(state) + } + fn is_turning( &self, math: &mut M, @@ -116,4 +236,6 @@ pub trait Hamiltonian: SamplerStats + Sized { fn step_size(&self) -> f64; fn step_size_mut(&mut self) -> &mut f64; + + fn max_energy_error(&self) -> f64; } diff --git a/src/lib.rs b/src/lib.rs index b722803..c0e8c8a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,7 +125,7 @@ pub use cpu_math::{CpuLogpFunc, CpuMath, CpuMathError}; pub use hamiltonian::DivergenceInfo; pub use math_base::{LogpError, Math}; pub use model::Model; -pub use nuts::NutsError; +pub use nuts::{NutsError, WalnutsOptions}; pub use sampler::{ ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, NutsSettings, Progress, ProgressCallback, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings, diff --git a/src/nuts.rs b/src/nuts.rs index 6c5af53..8cfd4e7 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -1,3 +1,4 @@ +use serde::Serialize; use thiserror::Error; use std::{fmt::Debug, marker::PhantomData}; @@ -34,6 +35,7 @@ pub trait Collector> { _start: &State, _end: &State, _divergence_info: Option<&DivergenceInfo>, + _num_substeps: u64, ) { } fn register_draw(&mut self, _math: &mut M, _state: &State, _info: &SampleInfo) {} @@ -59,22 +61,41 @@ pub struct SampleInfo { } /// A part of the trajectory tree during NUTS sampling. +/// +/// Corresponds to SpanW in walnuts C++ code struct NutsTree, C: Collector> { /// The left position of the tree. /// /// The left side always has the smaller index_in_trajectory. /// Leapfrogs in backward direction will replace the left. + /// + /// theta_bk_, rho_bk_, grad_theta_bk_, logp_bk_ in C++ code left: State, + + /// The right position of the tree. + /// + /// theta_fw_, rho_fw_, grad_theta_fw_, logp_fw_ in C++ code right: State, /// A draw from the trajectory between left and right using /// multinomial sampling. + /// + /// theta_select_ in C++ code draw: State, + + /// Constant for acceptance probability + /// + /// logp_ in C++ code log_size: f64, + + /// The depth of the tree depth: u64, /// A tree is the main tree if it contains the initial point /// of the trajectory. + /// + /// This is used to determine whether to use Metropolis + /// accptance or Barker is_main: bool, _phantom2: PhantomData, } @@ -115,20 +136,23 @@ impl, C: Collector> NutsTree { direction: Direction, collector: &mut C, options: &NutsOptions, + early: bool, ) -> ExtendResult where H: Hamiltonian, R: rand::Rng + ?Sized, { - let mut other = match self.single_step(math, hamiltonian, direction, collector) { - Ok(Ok(tree)) => tree, - Ok(Err(info)) => return ExtendResult::Diverging(self, info), - Err(err) => return ExtendResult::Err(err), - }; + let mut other = + match self.single_step(math, hamiltonian, direction, options, collector, early) { + Ok(Ok(tree)) => tree, + Ok(Err(info)) => return ExtendResult::Diverging(self, info), + Err(err) => return ExtendResult::Err(err), + }; while other.depth < self.depth { use ExtendResult::*; - other = match other.extend(math, rng, hamiltonian, direction, collector, options) { + other = match other.extend(math, rng, hamiltonian, direction, collector, options, early) + { Ok(tree) => tree, Turning(_) => { return Turning(self); @@ -171,6 +195,7 @@ impl, C: Collector> NutsTree { } } + // `combine` in C++ code fn merge_into( &mut self, _math: &mut M, @@ -208,24 +233,109 @@ impl, C: Collector> NutsTree { self.log_size = log_size; } + // Corresponds to `build_leaf` in C++ code fn single_step( &self, math: &mut M, hamiltonian: &mut H, direction: Direction, + options: &NutsOptions, collector: &mut C, + early: bool, ) -> Result, DivergenceInfo>> { let start = match direction { Direction::Forward => &self.right, Direction::Backward => &self.left, }; - let end = match hamiltonian.leapfrog(math, start, direction, collector) { - LeapfrogResult::Divergence(info) => return Ok(Err(info)), - LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())), - LeapfrogResult::Ok(end) => end, + + let (log_size, end) = match options.walnuts_options { + Some(ref options) => { + // Walnuts implementation + // TODO: Shouldn't all be in this one big function... + let mut num_steps = 1; + let mut current = start.clone(); + + let mut last_divergence = None; + + for _ in 0..options.max_step_size_halvings { + current = match hamiltonian.split_leapfrog( + math, + start, + direction, + num_steps, + collector, + options.max_energy_error, + ) { + LeapfrogResult::Ok(state) => { + last_divergence = None; + state + } + LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())), + LeapfrogResult::Divergence(info) => { + num_steps *= 2; + last_divergence = Some(info); + continue; + } + }; + break; + } + + if let Some(info) = last_divergence { + let info = DivergenceInfo::new_max_step_size_halvings(math, num_steps, info); + return Ok(Err(info)); + } + + let back = direction.reverse(); + let mut reversible = true; + + while num_steps >= 2 { + num_steps /= 2; + + match hamiltonian.split_leapfrog( + math, + ¤t, + back, + num_steps, + collector, + options.max_energy_error, + ) { + LeapfrogResult::Ok(_) => (), + LeapfrogResult::Divergence(_) => { + // We also reject in the backward direction, all is good so far... + continue; + } + LeapfrogResult::Err(err) => { + return Err(NutsError::LogpFailure(err.into())); + } + }; + + // We did not reject in the backward direction, so we are not reversible + reversible = false; + break; + } + + if reversible || early { + let log_size = -current.point().energy_error(); + (log_size, current) + } else { + return Ok(Err(DivergenceInfo::new_not_reversible(math, start))); + } + } + None => { + // Classical NUTS. + // TODO Is equivalent to walnuts with max_step_size_halvings = 0? + let end = match hamiltonian.leapfrog(math, start, direction, 1, collector) { + LeapfrogResult::Divergence(info) => return Ok(Err(info)), + LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())), + LeapfrogResult::Ok(end) => end, + }; + + let log_size = -end.point().energy_error(); + + (log_size, end) + } }; - let log_size = -end.point().energy_error(); Ok(Ok(NutsTree { right: end.clone(), left: end.clone(), @@ -248,6 +358,23 @@ impl, C: Collector> NutsTree { } } +#[non_exhaustive] +#[derive(Debug, Clone, Copy, Serialize)] +pub struct WalnutsOptions { + pub max_step_size_halvings: u64, + pub max_energy_error: f64, +} + +impl Default for WalnutsOptions { + fn default() -> Self { + WalnutsOptions { + max_step_size_halvings: 10, + max_energy_error: 5.0, + } + } +} + +#[derive(Debug, Clone, Copy)] pub struct NutsOptions { pub maxdepth: u64, pub mindepth: u64, @@ -255,6 +382,8 @@ pub struct NutsOptions { pub store_unconstrained: bool, pub check_turning: bool, pub store_divergences: bool, + + pub walnuts_options: Option, } pub(crate) fn draw( @@ -264,6 +393,7 @@ pub(crate) fn draw( hamiltonian: &mut H, options: &NutsOptions, collector: &mut C, + early: bool, ) -> Result<(State, SampleInfo)> where M: Math, @@ -284,7 +414,7 @@ where while tree.depth < options.maxdepth { let direction: Direction = rng.random(); - tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) { + tree = match tree.extend(math, rng, hamiltonian, direction, collector, options, early) { ExtendResult::Ok(tree) => tree, ExtendResult::Turning(tree) => { if tree.depth < options.mindepth { diff --git a/src/sampler.rs b/src/sampler.rs index 0ccb9ca..72b1fc0 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -20,17 +20,17 @@ use std::{ }; use crate::{ - DiagAdaptExpSettings, + DiagAdaptExpSettings, Model, SamplerStats, adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions}, chain::{AdaptStrategy, Chain, NutsChain, StatOptions}, euclidean_hamiltonian::EuclideanHamiltonian, - mass_matrix::DiagMassMatrix, - mass_matrix::Strategy as DiagMassMatrixStrategy, - mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings}, + mass_matrix::{ + DiagMassMatrix, LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings, + Strategy as DiagMassMatrixStrategy, + }, math_base::Math, - model::Model, - nuts::NutsOptions, - sampler_stats::{SamplerStats, StatsDims}, + nuts::{NutsOptions, WalnutsOptions}, + sampler_stats::StatsDims, storage::{ChainStorage, StorageConfig, TraceStorage}, transform_adapt_strategy::{TransformAdaptation, TransformedSettings}, transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions}, @@ -185,6 +185,7 @@ pub struct NutsSettings { pub num_chains: usize, pub seed: u64, + pub walnuts_options: Option, } pub type DiagGradNutsSettings = NutsSettings>; @@ -206,6 +207,7 @@ impl Default for DiagGradNutsSettings { check_turning: true, seed: 0, num_chains: 6, + walnuts_options: None, } } } @@ -225,6 +227,7 @@ impl Default for LowRankNutsSettings { check_turning: true, seed: 0, num_chains: 6, + walnuts_options: None, }; vals.adapt_options.mass_matrix_update_freq = 10; vals @@ -246,6 +249,7 @@ impl Default for TransformedNutsSettings { check_turning: true, seed: 0, num_chains: 1, + walnuts_options: None, } } } @@ -278,6 +282,7 @@ impl Settings for LowRankNutsSettings { store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, check_turning: self.check_turning, + walnuts_options: self.walnuts_options, }; let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); @@ -346,6 +351,7 @@ impl Settings for DiagGradNutsSettings { store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, check_turning: self.check_turning, + walnuts_options: self.walnuts_options, }; let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); @@ -411,6 +417,7 @@ impl Settings for TransformedNutsSettings { store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, check_turning: self.check_turning, + walnuts_options: self.walnuts_options, }; let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); diff --git a/src/stepsize/adapt.rs b/src/stepsize/adapt.rs index 7dc9134..49cdd77 100644 --- a/src/stepsize/adapt.rs +++ b/src/stepsize/adapt.rs @@ -103,7 +103,7 @@ impl Strategy { *hamiltonian.step_size_mut() = self.options.initial_step; - let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector); + let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, 0, &mut collector); let LeapfrogResult::Ok(_) = state_next else { return Ok(()); @@ -119,7 +119,7 @@ impl Strategy { for _ in 0..100 { let mut collector = AcceptanceRateCollector::new(); collector.register_init(math, &state, options); - let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector); + let state_next = hamiltonian.leapfrog(math, &state, dir, 0, &mut collector); let LeapfrogResult::Ok(_) = state_next else { *hamiltonian.step_size_mut() = self.options.initial_step; return Ok(()); diff --git a/src/stepsize/dual_avg.rs b/src/stepsize/dual_avg.rs index 3f6d613..7c989a0 100644 --- a/src/stepsize/dual_avg.rs +++ b/src/stepsize/dual_avg.rs @@ -126,6 +126,7 @@ impl> Collector for AcceptanceRateCollector { _start: &State, end: &State, divergence_info: Option<&DivergenceInfo>, + _num_substeps: u64, ) { match divergence_info { Some(_) => { diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs index 6360ab1..307c0fa 100644 --- a/src/transform_adapt_strategy.rs +++ b/src/transform_adapt_strategy.rs @@ -81,6 +81,7 @@ impl> Collector for DrawCollector { _start: &State, end: &State, divergence_info: Option<&crate::DivergenceInfo>, + num_substeps: u64, ) { if divergence_info.is_some() { return; diff --git a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs index 7b97482..9f0cf27 100644 --- a/src/transformed_hamiltonian.rs +++ b/src/transformed_hamiltonian.rs @@ -303,6 +303,7 @@ impl Hamiltonian for TransformedHamiltonian { math: &mut M, start: &State, dir: Direction, + step_size_splits: u64, collector: &mut C, ) -> LeapfrogResult { let mut out = self.pool().new_state(math); @@ -316,7 +317,7 @@ impl Hamiltonian for TransformedHamiltonian { Direction::Backward => -1, }; - let epsilon = (sign as f64) * self.step_size; + let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64); start .point() @@ -327,17 +328,9 @@ impl Hamiltonian for TransformedHamiltonian { if !logp_error.is_recoverable() { return LeapfrogResult::Err(logp_error); } - let div_info = DivergenceInfo { - logp_function_error: Some(Arc::new(Box::new(logp_error))), - start_location: Some(math.box_array(start.point().position())), - start_gradient: Some(math.box_array(start.point().gradient())), - start_momentum: None, - end_location: None, - start_idx_in_trajectory: Some(start.point().index_in_trajectory()), - end_idx_in_trajectory: None, - energy_error: None, - }; - collector.register_leapfrog(math, start, &out, Some(&div_info)); + let logp_error = Arc::new(Box::new(logp_error)); + let div_info = DivergenceInfo::new_logp_function_error(math, start, logp_error); + collector.register_leapfrog(math, start, &out, Some(&div_info), step_size_splits); return LeapfrogResult::Divergence(div_info); } @@ -348,21 +341,18 @@ impl Hamiltonian for TransformedHamiltonian { let energy_error = out_point.energy_error(); if (energy_error > self.max_energy_error) | !energy_error.is_finite() { - let divergence_info = DivergenceInfo { - logp_function_error: None, - start_location: Some(math.box_array(start.point().position())), - start_gradient: Some(math.box_array(start.point().gradient())), - end_location: Some(math.box_array(out_point.position())), - start_momentum: None, - start_idx_in_trajectory: Some(start.index_in_trajectory()), - end_idx_in_trajectory: Some(out.index_in_trajectory()), - energy_error: Some(energy_error), - }; - collector.register_leapfrog(math, start, &out, Some(&divergence_info)); + let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out); + collector.register_leapfrog( + math, + start, + &out, + Some(&divergence_info), + step_size_splits, + ); return LeapfrogResult::Divergence(divergence_info); } - collector.register_leapfrog(math, start, &out, None); + collector.register_leapfrog(math, start, &out, None, step_size_splits); LeapfrogResult::Ok(out) } @@ -464,4 +454,8 @@ impl Hamiltonian for TransformedHamiltonian { fn step_size_mut(&mut self) -> &mut f64 { &mut self.step_size } + + fn max_energy_error(&self) -> f64 { + self.max_energy_error + } }