Skip to content

Commit 1752c63

Browse files
committed
feat: Add transforming adaptation
1 parent 0f246f2 commit 1752c63

15 files changed

+347
-230
lines changed

src/adapt_strategy.rs

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
9090
>;
9191
type Options = EuclideanAdaptOptions<A::Options>;
9292

93-
fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self {
93+
fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
9494
let num_tune_f = num_tune as f64;
9595
let step_size_window = (options.step_size_window * num_tune_f) as u64;
9696
let early_end = (options.early_window * num_tune_f) as u64;
@@ -100,7 +100,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
100100

101101
Self {
102102
step_size: StepSizeStrategy::new(options.dual_average_options),
103-
mass_matrix: A::new(math, options.mass_matrix_options, num_tune),
103+
mass_matrix: A::new(math, options.mass_matrix_options, num_tune, chain),
104104
options,
105105
num_tune,
106106
early_end,
@@ -116,13 +116,13 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
116116
math: &mut M,
117117
options: &mut NutsOptions,
118118
hamiltonian: &mut Self::Hamiltonian,
119-
state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
119+
position: &[f64],
120120
rng: &mut R,
121121
) -> Result<(), NutsError> {
122122
self.mass_matrix
123-
.init(math, options, hamiltonian, state, rng)?;
123+
.init(math, options, hamiltonian, position, rng)?;
124124
self.step_size
125-
.init(math, options, hamiltonian, state, rng)?;
125+
.init(math, options, hamiltonian, position, rng)?;
126126
Ok(())
127127
}
128128

@@ -186,8 +186,9 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
186186
// First time we change the mass matrix
187187
if did_change & self.has_initial_mass_matrix {
188188
self.has_initial_mass_matrix = false;
189+
let position = math.box_array(state.point().position());
189190
self.step_size
190-
.init(math, options, hamiltonian, state, rng)?;
191+
.init(math, options, hamiltonian, &position, rng)?;
191192
} else {
192193
self.step_size.update_stepsize(hamiltonian, false)
193194
}
@@ -403,7 +404,18 @@ pub mod test_logps {
403404
unimplemented!()
404405
}
405406

406-
fn transformed_logp(
407+
fn init_from_transformed_position(
408+
&mut self,
409+
_params: &Self::TransformParams,
410+
_untransformed_position: &mut [f64],
411+
_untransformed_gradient: &mut [f64],
412+
_transformed_position: &[f64],
413+
_transformed_gradient: &mut [f64],
414+
) -> Result<(f64, f64), Self::LogpError> {
415+
unimplemented!()
416+
}
417+
418+
fn init_from_untransformed_position(
407419
&mut self,
408420
_params: &Self::TransformParams,
409421
_untransformed_position: &[f64],
@@ -424,15 +436,19 @@ pub mod test_logps {
424436
unimplemented!()
425437
}
426438

427-
fn new_transformation(
439+
fn new_transformation<R: rand::Rng + ?Sized>(
428440
&mut self,
441+
_rng: &mut R,
429442
_untransformed_position: &[f64],
430443
_untransfogmed_gradient: &[f64],
431444
) -> Result<Self::TransformParams, Self::LogpError> {
432445
unimplemented!()
433446
}
434447

435-
fn transformation_id(&self, _params: &Self::TransformParams) -> i64 {
448+
fn transformation_id(
449+
&self,
450+
_params: &Self::TransformParams,
451+
) -> Result<i64, Self::LogpError> {
436452
unimplemented!()
437453
}
438454
}
@@ -462,7 +478,8 @@ mod test {
462478
let max_energy_error = 1000f64;
463479
let step_size = 0.1f64;
464480

465-
let hamiltonian = EuclideanHamiltonian::new(mass_matrix, max_energy_error, step_size);
481+
let hamiltonian =
482+
EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, step_size);
466483
let options = NutsOptions {
467484
maxdepth: 10u64,
468485
store_gradient: true,

src/chain.rs

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::{
66
hamiltonian::{Hamiltonian, Point},
77
nuts::{draw, Collector, NutsOptions, NutsSampleStats, NutsStatsBuilder},
88
sampler_stats::SamplerStats,
9-
state::{State, StatePool},
9+
state::State,
1010
Math, NutsError, Settings,
1111
};
1212

@@ -35,7 +35,6 @@ where
3535
R: rand::Rng,
3636
A: AdaptStrategy<M>,
3737
{
38-
pool: StatePool<M, <A::Hamiltonian as Hamiltonian<M>>::Point>,
3938
hamiltonian: A::Hamiltonian,
4039
collector: A::Collector,
4140
options: NutsOptions,
@@ -56,18 +55,15 @@ where
5655
{
5756
pub fn new(
5857
mut math: M,
59-
hamiltonian: A::Hamiltonian,
58+
mut hamiltonian: A::Hamiltonian,
6059
strategy: A,
6160
options: NutsOptions,
6261
rng: R,
6362
chain: u64,
6463
) -> Self {
65-
let pool_size: usize = options.maxdepth.checked_mul(2).unwrap().try_into().unwrap();
66-
let pool = hamiltonian.new_pool(&mut math, pool_size);
67-
let init = pool.new_state(&mut math);
64+
let init = hamiltonian.pool().new_state(&mut math);
6865
let collector = strategy.new_collector(&mut math);
6966
NutsChain {
70-
pool,
7167
hamiltonian,
7268
collector,
7369
options,
@@ -87,14 +83,14 @@ pub trait AdaptStrategy<M: Math>: SamplerStats<M> {
8783
type Collector: Collector<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>;
8884
type Options: Copy + Send + Debug + Default;
8985

90-
fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self;
86+
fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self;
9187

9288
fn init<R: Rng + ?Sized>(
9389
&mut self,
9490
math: &mut M,
9591
options: &mut NutsOptions,
9692
hamiltonian: &mut Self::Hamiltonian,
97-
state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
93+
position: &[f64],
9894
rng: &mut R,
9995
) -> Result<(), NutsError>;
10096

@@ -151,24 +147,20 @@ where
151147
type AdaptStrategy = A;
152148

153149
fn set_position(&mut self, position: &[f64]) -> Result<()> {
154-
let state = self
155-
.hamiltonian
156-
.init_state(&mut self.math, &mut self.pool, position)?;
157-
self.init = state;
158150
self.strategy.init(
159151
&mut self.math,
160152
&mut self.options,
161153
&mut self.hamiltonian,
162-
&self.init,
154+
position,
163155
&mut self.rng,
164156
)?;
157+
self.init = self.hamiltonian.init_state(&mut self.math, position)?;
165158
Ok(())
166159
}
167160

168161
fn draw(&mut self) -> Result<(Box<[f64]>, Self::Stats)> {
169162
let (state, info) = draw(
170163
&mut self.math,
171-
&mut self.pool,
172164
&mut self.init,
173165
&mut self.rng,
174166
&mut self.hamiltonian,

src/cpu_math.rs

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,15 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
360360
)
361361
}
362362

363-
fn transformed_logp(
363+
fn init_from_untransformed_position(
364364
&mut self,
365365
params: &Self::TransformParams,
366366
untransformed_position: &Self::Vector,
367367
untransformed_gradient: &mut Self::Vector,
368368
transformed_position: &mut Self::Vector,
369369
transformed_gradient: &mut Self::Vector,
370370
) -> Result<(f64, f64), Self::LogpErr> {
371-
self.logp_func.transformed_logp(
371+
self.logp_func.init_from_untransformed_position(
372372
params,
373373
untransformed_position.as_slice(),
374374
untransformed_gradient.as_slice_mut(),
@@ -377,11 +377,28 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
377377
)
378378
}
379379

380+
fn init_from_transformed_position(
381+
&mut self,
382+
params: &Self::TransformParams,
383+
untransformed_position: &mut Self::Vector,
384+
untransformed_gradient: &mut Self::Vector,
385+
transformed_position: &Self::Vector,
386+
transformed_gradient: &mut Self::Vector,
387+
) -> Result<(f64, f64), Self::LogpErr> {
388+
self.logp_func.init_from_transformed_position(
389+
params,
390+
untransformed_position.as_slice_mut(),
391+
untransformed_gradient.as_slice_mut(),
392+
transformed_position.as_slice(),
393+
transformed_gradient.as_slice_mut(),
394+
)
395+
}
396+
380397
fn update_transformation<'a, R: rand::Rng + ?Sized>(
381398
&'a mut self,
382399
rng: &mut R,
383-
untransformed_positions: impl Iterator<Item = &'a Self::Vector>,
384-
untransformed_gradients: impl Iterator<Item = &'a Self::Vector>,
400+
untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
401+
untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
385402
params: &'a mut Self::TransformParams,
386403
) -> Result<(), Self::LogpErr> {
387404
self.logp_func.update_transformation(
@@ -392,18 +409,22 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
392409
)
393410
}
394411

395-
fn new_transformation(
412+
fn new_transformation<R: rand::Rng + ?Sized>(
396413
&mut self,
414+
rng: &mut R,
397415
untransformed_position: &Self::Vector,
398416
untransfogmed_gradient: &Self::Vector,
417+
chain: u64,
399418
) -> Result<Self::TransformParams, Self::LogpErr> {
400419
self.logp_func.new_transformation(
420+
rng,
401421
untransformed_position.as_slice(),
402422
untransfogmed_gradient.as_slice(),
423+
chain,
403424
)
404425
}
405426

406-
fn transformation_id(&self, params: &Self::TransformParams) -> i64 {
427+
fn transformation_id(&self, params: &Self::TransformParams) -> Result<i64, Self::LogpErr> {
407428
self.logp_func.transformation_id(params)
408429
}
409430
}
@@ -417,35 +438,58 @@ pub trait CpuLogpFunc {
417438

418439
fn inv_transform_normalize(
419440
&mut self,
420-
params: &Self::TransformParams,
421-
untransformed_position: &[f64],
422-
untransofrmed_gradient: &[f64],
423-
transformed_position: &mut [f64],
424-
transformed_gradient: &mut [f64],
425-
) -> Result<f64, Self::LogpError>;
441+
_params: &Self::TransformParams,
442+
_untransformed_position: &[f64],
443+
_untransformed_gradient: &[f64],
444+
_transformed_position: &mut [f64],
445+
_transformed_gradient: &mut [f64],
446+
) -> Result<f64, Self::LogpError> {
447+
unimplemented!()
448+
}
426449

427-
fn transformed_logp(
450+
fn init_from_untransformed_position(
428451
&mut self,
429-
params: &Self::TransformParams,
430-
untransformed_position: &[f64],
431-
untransformed_gradient: &mut [f64],
432-
transformed_position: &mut [f64],
433-
transformed_gradient: &mut [f64],
434-
) -> Result<(f64, f64), Self::LogpError>;
452+
_params: &Self::TransformParams,
453+
_untransformed_position: &[f64],
454+
_untransformed_gradient: &mut [f64],
455+
_transformed_position: &mut [f64],
456+
_transformed_gradient: &mut [f64],
457+
) -> Result<(f64, f64), Self::LogpError> {
458+
unimplemented!()
459+
}
460+
461+
fn init_from_transformed_position(
462+
&mut self,
463+
_params: &Self::TransformParams,
464+
_untransformed_position: &mut [f64],
465+
_untransformed_gradient: &mut [f64],
466+
_transformed_position: &[f64],
467+
_transformed_gradient: &mut [f64],
468+
) -> Result<(f64, f64), Self::LogpError> {
469+
unimplemented!()
470+
}
435471

436472
fn update_transformation<'a, R: rand::Rng + ?Sized>(
437473
&'a mut self,
438-
rng: &mut R,
439-
untransformed_positions: impl Iterator<Item = &'a [f64]>,
440-
untransformed_gradients: impl Iterator<Item = &'a [f64]>,
441-
params: &'a mut Self::TransformParams,
442-
) -> Result<(), Self::LogpError>;
474+
_rng: &mut R,
475+
_untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
476+
_untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
477+
_params: &'a mut Self::TransformParams,
478+
) -> Result<(), Self::LogpError> {
479+
unimplemented!()
480+
}
443481

444-
fn new_transformation(
482+
fn new_transformation<R: rand::Rng + ?Sized>(
445483
&mut self,
446-
untransformed_position: &[f64],
447-
untransfogmed_gradient: &[f64],
448-
) -> Result<Self::TransformParams, Self::LogpError>;
484+
_rng: &mut R,
485+
_untransformed_position: &[f64],
486+
_untransformed_gradient: &[f64],
487+
_chain: u64,
488+
) -> Result<Self::TransformParams, Self::LogpError> {
489+
unimplemented!()
490+
}
449491

450-
fn transformation_id(&self, params: &Self::TransformParams) -> i64;
492+
fn transformation_id(&self, _params: &Self::TransformParams) -> Result<i64, Self::LogpError> {
493+
unimplemented!()
494+
}
451495
}

0 commit comments

Comments
 (0)