Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/adapt_strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,12 @@ where
start: &State<M, P>,
end: &State<M, P>,
divergence_info: Option<&DivergenceInfo>,
num_substeps: u64,
) {
self.collector1
.register_leapfrog(math, start, end, divergence_info);
.register_leapfrog(math, start, end, divergence_info, num_substeps);
self.collector2
.register_leapfrog(math, start, end, divergence_info);
.register_leapfrog(math, start, end, divergence_info, num_substeps);
}

fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
Expand Down Expand Up @@ -482,6 +483,7 @@ mod test {
store_unconstrained: true,
check_turning: true,
store_divergences: false,
walnuts_options: None,
};

let rng = {
Expand Down
4 changes: 3 additions & 1 deletion src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ where
&mut self.hamiltonian,
&self.options,
&mut self.collector,
self.draw_count < 70,
)?;
let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
state.write_position(math, &mut position);
Expand Down Expand Up @@ -235,6 +236,7 @@ pub struct NutsStats<P: HasDims, H: Storable<P>, A: Storable<P>, D: Storable<P>>
pub divergence_end: Option<Vec<f64>>,
#[storable(dims("unconstrained_parameter"))]
pub divergence_momentum: Option<Vec<f64>>,
non_reversible: Option<bool>,
//pub divergence_message: Option<String>,
#[storable(ignore)]
_phantom: PhantomData<fn() -> P>,
Expand Down Expand Up @@ -303,7 +305,7 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> SamplerStats<M> for NutsChain<M
.and_then(|d| d.end_location.as_ref().map(|v| v.as_ref().to_vec())),
divergence_momentum: div_info
.and_then(|d| d.start_momentum.as_ref().map(|v| v.as_ref().to_vec())),
//divergence_message: self.divergence_msg.clone(),
non_reversible: div_info.and_then(|d| Some(d.non_reversible)),
_phantom: PhantomData,
}
}
Expand Down
43 changes: 19 additions & 24 deletions src/euclidean_hamiltonian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
math: &mut M,
start: &State<M, Self::Point>,
dir: Direction,
step_size_splits: u64,
collector: &mut C,
) -> LeapfrogResult<M, Self::Point> {
let mut out = self.pool().new_state(math);
Expand All @@ -237,7 +238,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
Direction::Backward => -1,
};

let epsilon = (sign as f64) * self.step_size;
let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64);

start
.point()
Expand All @@ -249,17 +250,9 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
if !logp_error.is_recoverable() {
return LeapfrogResult::Err(logp_error);
}
let div_info = DivergenceInfo {
logp_function_error: Some(Arc::new(Box::new(logp_error))),
start_location: Some(math.box_array(start.point().position())),
start_gradient: Some(math.box_array(&start.point().gradient)),
start_momentum: Some(math.box_array(&start.point().momentum)),
end_location: None,
start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
end_idx_in_trajectory: None,
energy_error: None,
};
collector.register_leapfrog(math, start, &out, Some(&div_info));
let error = Arc::new(Box::new(logp_error));
let div_info = DivergenceInfo::new_logp_function_error(math, start, error);
collector.register_leapfrog(math, start, &out, Some(&div_info), step_size_splits);
return LeapfrogResult::Divergence(div_info);
}

Expand All @@ -272,23 +265,21 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma

start.point().set_psum(math, out_point, dir);

// TODO: energy error measured relative to initial point or previous point?
let energy_error = out_point.energy_error();
if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
let divergence_info = DivergenceInfo {
logp_function_error: None,
start_location: Some(math.box_array(start.point().position())),
start_gradient: Some(math.box_array(start.point().gradient())),
end_location: Some(math.box_array(&out_point.position)),
start_momentum: Some(math.box_array(&out_point.momentum)),
start_idx_in_trajectory: Some(start.index_in_trajectory()),
end_idx_in_trajectory: Some(out.index_in_trajectory()),
energy_error: Some(energy_error),
};
collector.register_leapfrog(math, start, &out, Some(&divergence_info));
let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out);
collector.register_leapfrog(
math,
start,
&out,
Some(&divergence_info),
step_size_splits,
);
return LeapfrogResult::Divergence(divergence_info);
}

collector.register_leapfrog(math, start, &out, None);
collector.register_leapfrog(math, start, &out, None, step_size_splits);

LeapfrogResult::Ok(out)
}
Expand Down Expand Up @@ -362,4 +353,8 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
/// Mutable access to the leapfrog step size (used by step-size adaptation).
fn step_size_mut(&mut self) -> &mut f64 {
&mut self.step_size
}

/// The largest energy error allowed before a leapfrog step is treated
/// as a divergence.
fn max_energy_error(&self) -> f64 {
self.max_energy_error
}
}
122 changes: 122 additions & 0 deletions src/hamiltonian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{
/// a cutoff value or nan.
/// - The logp function caused a recoverable error (eg if an ODE solver
/// failed)
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct DivergenceInfo {
pub start_momentum: Option<Box<[f64]>>,
Expand All @@ -26,6 +27,81 @@ pub struct DivergenceInfo {
pub end_idx_in_trajectory: Option<i64>,
pub start_idx_in_trajectory: Option<i64>,
pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
pub non_reversible: bool,
}

impl DivergenceInfo {
    /// Create an empty `DivergenceInfo` with every field unset.
    pub fn new() -> Self {
        DivergenceInfo {
            start_momentum: None,
            start_location: None,
            start_gradient: None,
            end_location: None,
            energy_error: None,
            end_idx_in_trajectory: None,
            start_idx_in_trajectory: None,
            logp_function_error: None,
            non_reversible: false,
        }
    }

    /// Divergence caused by an energy error above the allowed maximum while
    /// integrating from `start` to `stop`.
    pub fn new_energy_error_too_large<M: Math>(
        math: &mut M,
        start: &State<M, impl Point<M>>,
        stop: &State<M, impl Point<M>>,
    ) -> Self {
        DivergenceInfo {
            start_location: Some(math.box_array(start.point().position())),
            start_gradient: Some(math.box_array(start.point().gradient())),
            start_idx_in_trajectory: Some(start.index_in_trajectory()),
            end_location: Some(math.box_array(stop.point().position())),
            end_idx_in_trajectory: Some(stop.index_in_trajectory()),
            // TODO: record the start momentum and the actual energy error.
            ..Self::new()
        }
    }

    /// Divergence caused by a recoverable error raised by the logp function.
    pub fn new_logp_function_error<M: Math>(
        math: &mut M,
        start: &State<M, impl Point<M>>,
        logp_function_error: Arc<dyn std::error::Error + Send + Sync>,
    ) -> Self {
        DivergenceInfo {
            logp_function_error: Some(logp_function_error),
            start_location: Some(math.box_array(start.point().position())),
            start_gradient: Some(math.box_array(start.point().gradient())),
            start_idx_in_trajectory: Some(start.index_in_trajectory()),
            // TODO: record the start momentum.
            ..Self::new()
        }
    }

    /// Divergence reported when a step could not be reversed to within
    /// tolerance (WALNUTS reversibility check).
    pub fn new_not_reversible<M: Math>(math: &mut M, start: &State<M, impl Point<M>>) -> Self {
        // TODO add info about what went wrong
        DivergenceInfo {
            non_reversible: true,
            start_location: Some(math.box_array(start.point().position())),
            start_gradient: Some(math.box_array(start.point().gradient())),
            start_idx_in_trajectory: Some(start.index_in_trajectory()),
            // TODO: record the start momentum.
            ..Self::new()
        }
    }

    /// Divergence reported when the maximum number of step-size halvings was
    /// reached. Currently passes `info` through unchanged.
    pub fn new_max_step_size_halvings<M: Math>(_math: &mut M, _num_steps: u64, info: Self) -> Self {
        // TODO: attach the number of halvings to the divergence info.
        info
    }
}

impl Default for DivergenceInfo {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Copy, Clone)]
Expand All @@ -34,6 +110,15 @@ pub enum Direction {
Backward,
}

impl Direction {
    /// Return the opposite integration direction.
    pub fn reverse(&self) -> Self {
        if matches!(self, Direction::Forward) {
            Direction::Backward
        } else {
            Direction::Forward
        }
    }
}

impl Distribution<Direction> for StandardUniform {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Direction {
if rng.random::<bool>() {
Expand Down Expand Up @@ -82,9 +167,44 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
math: &mut M,
start: &State<M, Self::Point>,
dir: Direction,
step_size_splits: u64,
collector: &mut C,
) -> LeapfrogResult<M, Self::Point>;

/// Perform `num_steps` leapfrog sub-steps (each with the step size divided
/// by `num_steps`), diverging if the energy spread along the sub-trajectory
/// exceeds `max_error`.
fn split_leapfrog<C: Collector<M, Self::Point>>(
    &mut self,
    math: &mut M,
    start: &State<M, Self::Point>,
    dir: Direction,
    num_steps: u64,
    collector: &mut C,
    max_error: f64,
) -> LeapfrogResult<M, Self::Point> {
    let mut current = start.clone();

    // Track the range of energies seen so far, seeded with the start state.
    let mut lowest_energy = start.energy();
    let mut highest_energy = lowest_energy;

    for _ in 0..num_steps {
        // Each sub-step uses the step size split into `num_steps` pieces.
        current = match self.leapfrog(math, &current, dir, num_steps, collector) {
            LeapfrogResult::Ok(next) => next,
            other => return other,
        };

        let energy = current.energy();
        lowest_energy = lowest_energy.min(energy);
        highest_energy = highest_energy.max(energy);

        // TODO: the WALNUTS paper says to use abs, but the C++ code doesn't?
        if highest_energy - lowest_energy > max_error {
            let info = DivergenceInfo::new_energy_error_too_large(math, start, &current);
            return LeapfrogResult::Divergence(info);
        }
    }

    LeapfrogResult::Ok(current)
}

fn is_turning(
&self,
math: &mut M,
Expand Down Expand Up @@ -116,4 +236,6 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {

fn step_size(&self) -> f64;
fn step_size_mut(&mut self) -> &mut f64;

fn max_energy_error(&self) -> f64;
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ pub use cpu_math::{CpuLogpFunc, CpuMath, CpuMathError};
pub use hamiltonian::DivergenceInfo;
pub use math_base::{LogpError, Math};
pub use model::Model;
pub use nuts::NutsError;
pub use nuts::{NutsError, WalnutsOptions};
pub use sampler::{
ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, NutsSettings, Progress,
ProgressCallback, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings,
Expand Down
Loading
Loading