Skip to content

Commit 0f246f2

Browse files
committed
More transform refactoring
1 parent 05b78cf commit 0f246f2

21 files changed

+2177
-1139
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ rand_distr = "0.4.3"
2323
multiversion = "0.7.2"
2424
itertools = "0.13.0"
2525
thiserror = "1.0.43"
26-
arrow = { version = "53.0.0", default-features = false, features = ["ffi"] }
26+
arrow = { version = "53.1.0", default-features = false, features = ["ffi"] }
2727
rand_chacha = "0.3.1"
2828
anyhow = "1.0.72"
2929
faer = { version = "0.19.4", default-features = false, features = ["std"] }

src/adapt_strategy.rs

Lines changed: 126 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,26 @@ use itertools::Itertools;
55
use rand::Rng;
66

77
use crate::{
8+
chain::AdaptStrategy,
9+
hamiltonian::{DivergenceInfo, Hamiltonian, Point},
810
mass_matrix_adapt::MassMatrixAdaptStrategy,
911
math_base::Math,
10-
nuts::{AdaptStats, AdaptStrategy, Collector, NutsOptions},
12+
nuts::{Collector, NutsOptions},
1113
sampler::Settings,
14+
sampler_stats::{SamplerStats, StatTraceBuilder},
1215
state::State,
1316
stepsize::AcceptanceRateCollector,
1417
stepsize_adapt::{
1518
DualAverageSettings, Stats as StepSizeStats, StatsBuilder as StepSizeStatsBuilder,
1619
Strategy as StepSizeStrategy,
1720
},
18-
DivergenceInfo,
21+
NutsError,
1922
};
2023

21-
use crate::nuts::{SamplerStats, StatTraceBuilder};
22-
2324
pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
2425
step_size: StepSizeStrategy,
2526
mass_matrix: A,
26-
options: AdaptOptions<A::Options>,
27+
options: EuclideanAdaptOptions<A::Options>,
2728
num_tune: u64,
2829
// The number of draws in the early window
2930
early_end: u64,
@@ -36,7 +37,7 @@ pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
3637
}
3738

3839
#[derive(Debug, Clone, Copy)]
39-
pub struct AdaptOptions<S: Debug + Default> {
40+
pub struct EuclideanAdaptOptions<S: Debug + Default> {
4041
pub dual_average_options: DualAverageSettings,
4142
pub mass_matrix_options: S,
4243
pub early_window: f64,
@@ -46,7 +47,7 @@ pub struct AdaptOptions<S: Debug + Default> {
4647
pub mass_matrix_update_freq: u64,
4748
}
4849

49-
impl<S: Debug + Default> Default for AdaptOptions<S> {
50+
impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
5051
fn default() -> Self {
5152
Self {
5253
dual_average_options: DualAverageSettings::default(),
@@ -79,16 +80,15 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> SamplerStats<M> for GlobalStrategy<
7980
}
8081
}
8182

82-
impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStats<M> for GlobalStrategy<M, A> {
83-
fn num_grad_evals(stats: &Self::Stats) -> usize {
84-
stats.stats1.n_steps as usize
85-
}
86-
}
87-
8883
impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
89-
type Potential = A::Potential;
90-
type Collector = CombinedCollector<M, AcceptanceRateCollector, A::Collector>;
91-
type Options = AdaptOptions<A::Options>;
84+
type Hamiltonian = A::Hamiltonian;
85+
type Collector = CombinedCollector<
86+
M,
87+
<Self::Hamiltonian as Hamiltonian<M>>::Point,
88+
AcceptanceRateCollector,
89+
A::Collector,
90+
>;
91+
type Options = EuclideanAdaptOptions<A::Options>;
9292

9393
fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self {
9494
let num_tune_f = num_tune as f64;
@@ -115,29 +115,32 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
115115
&mut self,
116116
math: &mut M,
117117
options: &mut NutsOptions,
118-
potential: &mut Self::Potential,
119-
state: &State<M>,
118+
hamiltonian: &mut Self::Hamiltonian,
119+
state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
120120
rng: &mut R,
121-
) {
122-
self.mass_matrix.init(math, options, potential, state, rng);
123-
self.step_size.init(math, options, potential, state, rng);
121+
) -> Result<(), NutsError> {
122+
self.mass_matrix
123+
.init(math, options, hamiltonian, state, rng)?;
124+
self.step_size
125+
.init(math, options, hamiltonian, state, rng)?;
126+
Ok(())
124127
}
125128

126129
fn adapt<R: Rng + ?Sized>(
127130
&mut self,
128131
math: &mut M,
129132
options: &mut NutsOptions,
130-
potential: &mut Self::Potential,
133+
hamiltonian: &mut Self::Hamiltonian,
131134
draw: u64,
132135
collector: &Self::Collector,
133-
state: &State<M>,
136+
state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
134137
rng: &mut R,
135-
) {
138+
) -> Result<(), NutsError> {
136139
self.step_size.update(&collector.collector1);
137140

138141
if draw >= self.num_tune {
139142
self.tuning = false;
140-
return;
143+
return Ok(());
141144
}
142145

143146
if draw < self.final_step_size_window {
@@ -165,7 +168,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
165168
let did_change = if force_update
166169
| (draw - self.last_update >= self.options.mass_matrix_update_freq)
167170
{
168-
self.mass_matrix.update_potential(math, potential)
171+
self.mass_matrix.update_potential(math, hamiltonian)
169172
} else {
170173
false
171174
};
@@ -183,24 +186,25 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
183186
// First time we change the mass matrix
184187
if did_change & self.has_initial_mass_matrix {
185188
self.has_initial_mass_matrix = false;
186-
self.step_size.init(math, options, potential, state, rng);
189+
self.step_size
190+
.init(math, options, hamiltonian, state, rng)?;
187191
} else {
188-
self.step_size.update_stepsize(potential, false)
192+
self.step_size.update_stepsize(hamiltonian, false)
189193
}
190-
return;
194+
return Ok(());
191195
}
192196

193197
self.step_size.update_estimator_late();
194198
let is_last = draw == self.num_tune - 1;
195-
self.step_size.update_stepsize(potential, is_last);
199+
self.step_size.update_stepsize(hamiltonian, is_last);
200+
Ok(())
196201
}
197202

198203
fn new_collector(&self, math: &mut M) -> Self::Collector {
199-
CombinedCollector {
200-
collector1: self.step_size.new_collector(),
201-
collector2: self.mass_matrix.new_collector(math),
202-
_phantom: PhantomData,
203-
}
204+
Self::Collector::new(
205+
self.step_size.new_collector(),
206+
self.mass_matrix.new_collector(math),
207+
)
204208
}
205209

206210
fn is_tuning(&self) -> bool {
@@ -277,22 +281,48 @@ where
277281
}
278282
}
279283

280-
pub struct CombinedCollector<M: Math, C1: Collector<M>, C2: Collector<M>> {
281-
collector1: C1,
282-
collector2: C2,
284+
pub struct CombinedCollector<M, P, C1, C2>
285+
where
286+
M: Math,
287+
P: Point<M>,
288+
C1: Collector<M, P>,
289+
C2: Collector<M, P>,
290+
{
291+
pub collector1: C1,
292+
pub collector2: C2,
283293
_phantom: PhantomData<M>,
294+
_phantom2: PhantomData<P>,
284295
}
285296

286-
impl<M: Math, C1, C2> Collector<M> for CombinedCollector<M, C1, C2>
297+
impl<M, P, C1, C2> CombinedCollector<M, P, C1, C2>
287298
where
288-
C1: Collector<M>,
289-
C2: Collector<M>,
299+
M: Math,
300+
P: Point<M>,
301+
C1: Collector<M, P>,
302+
C2: Collector<M, P>,
303+
{
304+
pub fn new(collector1: C1, collector2: C2) -> Self {
305+
CombinedCollector {
306+
collector1,
307+
collector2,
308+
_phantom: PhantomData,
309+
_phantom2: PhantomData,
310+
}
311+
}
312+
}
313+
314+
impl<M, P, C1, C2> Collector<M, P> for CombinedCollector<M, P, C1, C2>
315+
where
316+
M: Math,
317+
P: Point<M>,
318+
C1: Collector<M, P>,
319+
C2: Collector<M, P>,
290320
{
291321
fn register_leapfrog(
292322
&mut self,
293323
math: &mut M,
294-
start: &State<M>,
295-
end: &State<M>,
324+
start: &State<M, P>,
325+
end: &State<M, P>,
296326
divergence_info: Option<&DivergenceInfo>,
297327
) {
298328
self.collector1
@@ -301,15 +331,15 @@ where
301331
.register_leapfrog(math, start, end, divergence_info);
302332
}
303333

304-
fn register_draw(&mut self, math: &mut M, state: &State<M>, info: &crate::nuts::SampleInfo) {
334+
fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
305335
self.collector1.register_draw(math, state, info);
306336
self.collector2.register_draw(math, state, info);
307337
}
308338

309339
fn register_init(
310340
&mut self,
311341
math: &mut M,
312-
state: &State<M>,
342+
state: &State<M, P>,
313343
options: &crate::nuts::NutsOptions,
314344
) {
315345
self.collector1.register_init(math, state, options);
@@ -319,7 +349,7 @@ where
319349

320350
#[cfg(test)]
321351
pub mod test_logps {
322-
use crate::{cpu_math::CpuLogpFunc, nuts::LogpError};
352+
use crate::{cpu_math::CpuLogpFunc, math_base::LogpError};
323353
use thiserror::Error;
324354

325355
#[derive(Clone, Debug)]
@@ -344,6 +374,7 @@ pub mod test_logps {
344374

345375
impl CpuLogpFunc for NormalLogp {
346376
type LogpError = NormalLogpError;
377+
type TransformParams = ();
347378

348379
fn dim(&self) -> usize {
349380
self.dim
@@ -360,6 +391,50 @@ pub mod test_logps {
360391
}
361392
Ok(logp)
362393
}
394+
395+
fn inv_transform_normalize(
396+
&mut self,
397+
_params: &Self::TransformParams,
398+
_untransformed_position: &[f64],
399+
_untransofrmed_gradient: &[f64],
400+
_transformed_position: &mut [f64],
401+
_transformed_gradient: &mut [f64],
402+
) -> Result<f64, Self::LogpError> {
403+
unimplemented!()
404+
}
405+
406+
fn transformed_logp(
407+
&mut self,
408+
_params: &Self::TransformParams,
409+
_untransformed_position: &[f64],
410+
_untransformed_gradient: &mut [f64],
411+
_transformed_position: &mut [f64],
412+
_transformed_gradient: &mut [f64],
413+
) -> Result<(f64, f64), Self::LogpError> {
414+
unimplemented!()
415+
}
416+
417+
fn update_transformation<'a, R: rand::Rng + ?Sized>(
418+
&'a mut self,
419+
_rng: &mut R,
420+
_untransformed_positions: impl Iterator<Item = &'a [f64]>,
421+
_untransformed_gradients: impl Iterator<Item = &'a [f64]>,
422+
_params: &'a mut Self::TransformParams,
423+
) -> Result<(), Self::LogpError> {
424+
unimplemented!()
425+
}
426+
427+
fn new_transformation(
428+
&mut self,
429+
_untransformed_position: &[f64],
430+
_untransfogmed_gradient: &[f64],
431+
) -> Result<Self::TransformParams, Self::LogpError> {
432+
unimplemented!()
433+
}
434+
435+
fn transformation_id(&self, _params: &Self::TransformParams) -> i64 {
436+
unimplemented!()
437+
}
363438
}
364439
}
365440

@@ -368,11 +443,8 @@ mod test {
368443
use super::test_logps::NormalLogp;
369444
use super::*;
370445
use crate::{
371-
cpu_math::CpuMath,
372-
mass_matrix::DiagMassMatrix,
373-
nuts::{AdaptStrategy, Chain, NutsChain, NutsOptions},
374-
potential::EuclideanPotential,
375-
DiagAdaptExpSettings,
446+
chain::NutsChain, cpu_math::CpuMath, euclidean_hamiltonian::EuclideanHamiltonian,
447+
mass_matrix::DiagMassMatrix, Chain, DiagAdaptExpSettings,
376448
};
377449

378450
#[test]
@@ -383,14 +455,14 @@ mod test {
383455
let func = NormalLogp::new(ndim, 3.);
384456
let mut math = CpuMath::new(func);
385457
let num_tune = 100;
386-
let options = AdaptOptions::<DiagAdaptExpSettings>::default();
458+
let options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
387459
let strategy = GlobalStrategy::<_, Strategy<_>>::new(&mut math, options, num_tune);
388460

389461
let mass_matrix = DiagMassMatrix::new(&mut math, true);
390462
let max_energy_error = 1000f64;
391463
let step_size = 0.1f64;
392464

393-
let potential = EuclideanPotential::new(mass_matrix, max_energy_error, step_size);
465+
let hamiltonian = EuclideanHamiltonian::new(mass_matrix, max_energy_error, step_size);
394466
let options = NutsOptions {
395467
maxdepth: 10u64,
396468
store_gradient: true,
@@ -405,7 +477,7 @@ mod test {
405477
};
406478
let chain = 0u64;
407479

408-
let mut sampler = NutsChain::new(math, potential, strategy, options, rng, chain);
480+
let mut sampler = NutsChain::new(math, hamiltonian, strategy, options, rng, chain);
409481
sampler.set_position(&vec![1.5f64; ndim]).unwrap();
410482
for _ in 0..200 {
411483
sampler.draw().unwrap();

0 commit comments

Comments
 (0)