Skip to content

Commit fb63fb1

Browse files
committed
refactor: Refactor mass matrix adaptation traits
1 parent 1752c63 commit fb63fb1

File tree

8 files changed

+132
-148
lines changed

8 files changed

+132
-148
lines changed

src/adapt_strategy.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use rand::Rng;
66

77
use crate::{
88
chain::AdaptStrategy,
9+
euclidean_hamiltonian::EuclideanHamiltonian,
910
hamiltonian::{DivergenceInfo, Hamiltonian, Point},
1011
mass_matrix_adapt::MassMatrixAdaptStrategy,
1112
math_base::Math,
@@ -81,7 +82,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> SamplerStats<M> for GlobalStrategy<
8182
}
8283

8384
impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
84-
type Hamiltonian = A::Hamiltonian;
85+
type Hamiltonian = EuclideanHamiltonian<M, A::MassMatrix>;
8586
type Collector = CombinedCollector<
8687
M,
8788
<Self::Hamiltonian as Hamiltonian<M>>::Point,
@@ -119,8 +120,14 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
119120
position: &[f64],
120121
rng: &mut R,
121122
) -> Result<(), NutsError> {
122-
self.mass_matrix
123-
.init(math, options, hamiltonian, position, rng)?;
123+
let state = hamiltonian.init_state(math, position)?;
124+
self.mass_matrix.init(
125+
math,
126+
options,
127+
&mut hamiltonian.mass_matrix,
128+
state.point(),
129+
rng,
130+
)?;
124131
self.step_size
125132
.init(math, options, hamiltonian, position, rng)?;
126133
Ok(())
@@ -168,7 +175,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
168175
let did_change = if force_update
169176
| (draw - self.last_update >= self.options.mass_matrix_update_freq)
170177
{
171-
self.mass_matrix.update_potential(math, hamiltonian)
178+
self.mass_matrix.adapt(math, &mut hamiltonian.mass_matrix)
172179
} else {
173180
false
174181
};
@@ -221,8 +228,8 @@ pub struct CombinedStats<D1, D2> {
221228

222229
#[derive(Clone)]
223230
pub struct CombinedStatsBuilder<B1, B2> {
224-
stats1: B1,
225-
stats2: B2,
231+
pub stats1: B1,
232+
pub stats2: B2,
226233
}
227234

228235
impl<S1, S2, B1, B2> StatTraceBuilder<CombinedStats<S1, S2>> for CombinedStatsBuilder<B1, B2>
@@ -441,6 +448,7 @@ pub mod test_logps {
441448
_rng: &mut R,
442449
_untransformed_position: &[f64],
443450
_untransfogmed_gradient: &[f64],
451+
_chain: u64,
444452
) -> Result<Self::TransformParams, Self::LogpError> {
445453
unimplemented!()
446454
}
@@ -472,7 +480,7 @@ mod test {
472480
let mut math = CpuMath::new(func);
473481
let num_tune = 100;
474482
let options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
475-
let strategy = GlobalStrategy::<_, Strategy<_>>::new(&mut math, options, num_tune);
483+
let strategy = GlobalStrategy::<_, Strategy<_>>::new(&mut math, options, num_tune, 0u64);
476484

477485
let mass_matrix = DiagMassMatrix::new(&mut math, true);
478486
let max_energy_error = 1000f64;

src/low_rank_mass_matrix.rs

Lines changed: 13 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@ use faer::{Col, Mat, Scale};
88
use itertools::Itertools;
99

1010
use crate::{
11-
chain::AdaptStrategy,
12-
euclidean_hamiltonian::{EuclideanHamiltonian, EuclideanPoint},
13-
hamiltonian::{Hamiltonian, Point},
11+
euclidean_hamiltonian::EuclideanPoint,
12+
hamiltonian::Point,
1413
mass_matrix::{DrawGradCollector, MassMatrix},
1514
mass_matrix_adapt::MassMatrixAdaptStrategy,
1615
sampler_stats::{SamplerStats, StatTraceBuilder},
17-
state::State,
1816
Math, NutsError,
1917
};
2018

@@ -392,12 +390,12 @@ impl LowRankMassMatrixStrategy {
392390
}
393391
}
394392

395-
pub fn add_draw<M: Math>(&mut self, math: &mut M, state: &State<M, EuclideanPoint<M>>) {
393+
pub fn add_draw<M: Math>(&mut self, math: &mut M, point: &impl Point<M>) {
396394
assert!(math.dim() == self.ndim);
397395
let mut draw = vec![0f64; self.ndim];
398-
state.write_position(math, &mut draw);
396+
math.write_to_slice(point.position(), &mut draw);
399397
let mut grad = vec![0f64; self.ndim];
400-
state.write_gradient(math, &mut grad);
398+
math.write_to_slice(point.gradient(), &mut grad);
401399

402400
self.draws.push_back(draw);
403401
self.grads.push_back(grad);
@@ -569,8 +567,8 @@ impl<M: Math> SamplerStats<M> for LowRankMassMatrixStrategy {
569567
}
570568
}
571569

572-
impl<M: Math> AdaptStrategy<M> for LowRankMassMatrixStrategy {
573-
type Hamiltonian = EuclideanHamiltonian<M, LowRankMassMatrix<M>>;
570+
impl<M: Math> MassMatrixAdaptStrategy<M> for LowRankMassMatrixStrategy {
571+
type MassMatrix = LowRankMassMatrix<M>;
574572
type Collector = DrawGradCollector<M>;
575573
type Options = LowRankSettings;
576574

@@ -582,46 +580,19 @@ impl<M: Math> AdaptStrategy<M> for LowRankMassMatrixStrategy {
582580
&mut self,
583581
math: &mut M,
584582
_options: &mut crate::nuts::NutsOptions,
585-
hamiltonian: &mut Self::Hamiltonian,
586-
position: &[f64],
587-
_rng: &mut R,
588-
) -> Result<(), NutsError> {
589-
let state = hamiltonian.init_state(math, position)?;
590-
self.add_draw(math, &state);
591-
hamiltonian.mass_matrix.update_from_grad(
592-
math,
593-
state.point().gradient(),
594-
1f64,
595-
(1e-20, 1e20),
596-
);
597-
Ok(())
598-
}
599-
600-
fn adapt<R: rand::Rng + ?Sized>(
601-
&mut self,
602-
_math: &mut M,
603-
_options: &mut crate::nuts::NutsOptions,
604-
_potential: &mut Self::Hamiltonian,
605-
_draw: u64,
606-
_collector: &Self::Collector,
607-
_state: &State<M, EuclideanPoint<M>>,
583+
mass_matrix: &mut Self::MassMatrix,
584+
point: &impl Point<M>,
608585
_rng: &mut R,
609586
) -> Result<(), NutsError> {
587+
self.add_draw(math, point);
588+
mass_matrix.update_from_grad(math, point.gradient(), 1f64, (1e-20, 1e20));
610589
Ok(())
611590
}
612591

613592
fn new_collector(&self, math: &mut M) -> Self::Collector {
614593
DrawGradCollector::new(math)
615594
}
616595

617-
fn is_tuning(&self) -> bool {
618-
unreachable!()
619-
}
620-
}
621-
622-
impl<M: Math> MassMatrixAdaptStrategy<M> for LowRankMassMatrixStrategy {
623-
type MassMatrix = LowRankMassMatrix<M>;
624-
625596
fn update_estimators(&mut self, math: &mut M, collector: &Self::Collector) {
626597
if collector.is_good {
627598
let mut draw = vec![0f64; self.ndim];
@@ -651,11 +622,11 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for LowRankMassMatrixStrategy {
651622
self.draws.len().checked_sub(self.background_split).unwrap() as u64
652623
}
653624

654-
fn update_potential(&self, math: &mut M, potential: &mut Self::Hamiltonian) -> bool {
625+
fn adapt(&self, math: &mut M, mass_matrix: &mut Self::MassMatrix) -> bool {
655626
if <LowRankMassMatrixStrategy as MassMatrixAdaptStrategy<M>>::current_count(self) < 3 {
656627
return false;
657628
}
658-
self.update(math, &mut potential.mass_matrix);
629+
self.update(math, mass_matrix);
659630

660631
true
661632
}

src/mass_matrix_adapt.rs

Lines changed: 45 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@ use std::marker::PhantomData;
33
use rand::Rng;
44

55
use crate::{
6-
chain::AdaptStrategy,
7-
euclidean_hamiltonian::{EuclideanHamiltonian, EuclideanPoint},
8-
hamiltonian::{Hamiltonian, Point},
6+
euclidean_hamiltonian::EuclideanPoint,
7+
hamiltonian::Point,
98
mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance},
10-
nuts::NutsOptions,
9+
nuts::{Collector, NutsOptions},
1110
sampler_stats::SamplerStats,
12-
state::State,
1311
Math, NutsError, Settings,
1412
};
1513
const LOWER_LIMIT: f64 = 1e-20f64;
@@ -43,8 +41,10 @@ pub struct Strategy<M: Math> {
4341
_phantom: PhantomData<M>,
4442
}
4543

46-
pub trait MassMatrixAdaptStrategy<M: Math>: AdaptStrategy<M> {
44+
pub trait MassMatrixAdaptStrategy<M: Math>: SamplerStats<M> {
4745
type MassMatrix: MassMatrix<M>;
46+
type Collector: Collector<M, EuclideanPoint<M>>;
47+
type Options: std::fmt::Debug + Default + Clone + Send + Sync + Copy;
4848

4949
fn update_estimators(&mut self, math: &mut M, collector: &Self::Collector);
5050

@@ -55,11 +55,26 @@ pub trait MassMatrixAdaptStrategy<M: Math>: AdaptStrategy<M> {
5555
fn background_count(&self) -> u64;
5656

5757
/// Give the opportunity to update the potential and return if it was changed
58-
fn update_potential(&self, math: &mut M, potential: &mut Self::Hamiltonian) -> bool;
58+
fn adapt(&self, math: &mut M, mass_matrix: &mut Self::MassMatrix) -> bool;
59+
60+
fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self;
61+
62+
fn init<R: Rng + ?Sized>(
63+
&mut self,
64+
math: &mut M,
65+
_options: &mut NutsOptions,
66+
mass_matrix: &mut Self::MassMatrix,
67+
point: &impl Point<M>,
68+
_rng: &mut R,
69+
) -> Result<(), NutsError>;
70+
71+
fn new_collector(&self, math: &mut M) -> Self::Collector;
5972
}
6073

6174
impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
6275
type MassMatrix = DiagMassMatrix<M>;
76+
type Collector = DrawGradCollector<M>;
77+
type Options = DiagAdaptExpSettings;
6378

6479
fn update_estimators(&mut self, math: &mut M, collector: &DrawGradCollector<M>) {
6580
if collector.is_good {
@@ -88,11 +103,7 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
88103
}
89104

90105
/// Give the opportunity to update the potential and return if it was changed
91-
fn update_potential(
92-
&self,
93-
math: &mut M,
94-
potential: &mut EuclideanHamiltonian<M, Self::MassMatrix>,
95-
) -> bool {
106+
fn adapt(&self, math: &mut M, mass_matrix: &mut DiagMassMatrix<M>) -> bool {
96107
if self.current_count() < 3 {
97108
return false;
98109
}
@@ -102,7 +113,7 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
102113
assert!(draw_scale == grad_scale);
103114

104115
if self._settings.use_grad_based_estimate {
105-
potential.mass_matrix.update_diag_draw_grad(
116+
mass_matrix.update_diag_draw_grad(
106117
math,
107118
draw_var,
108119
grad_var,
@@ -111,35 +122,11 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
111122
);
112123
} else {
113124
let scale = (self.exp_variance_draw.count() as f64).recip();
114-
potential.mass_matrix.update_diag_draw(
115-
math,
116-
draw_var,
117-
scale,
118-
None,
119-
(LOWER_LIMIT, UPPER_LIMIT),
120-
);
125+
mass_matrix.update_diag_draw(math, draw_var, scale, None, (LOWER_LIMIT, UPPER_LIMIT));
121126
}
122127

123128
true
124129
}
125-
}
126-
127-
pub type Stats = ();
128-
pub type StatsBuilder = ();
129-
130-
impl<M: Math> SamplerStats<M> for Strategy<M> {
131-
type Builder = Stats;
132-
type Stats = StatsBuilder;
133-
134-
fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder {}
135-
136-
fn current_stats(&self, _math: &mut M) -> Self::Stats {}
137-
}
138-
139-
impl<M: Math> AdaptStrategy<M> for Strategy<M> {
140-
type Hamiltonian = EuclideanHamiltonian<M, DiagMassMatrix<M>>;
141-
type Collector = DrawGradCollector<M>;
142-
type Options = DiagAdaptExpSettings;
143130

144131
fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self {
145132
Self {
@@ -156,49 +143,37 @@ impl<M: Math> AdaptStrategy<M> for Strategy<M> {
156143
&mut self,
157144
math: &mut M,
158145
_options: &mut NutsOptions,
159-
hamiltonian: &mut Self::Hamiltonian,
160-
position: &[f64],
146+
mass_matrix: &mut Self::MassMatrix,
147+
point: &impl Point<M>,
161148
_rng: &mut R,
162149
) -> Result<(), NutsError> {
163-
let state = hamiltonian.init_state(math, position)?;
164-
165-
self.exp_variance_draw
166-
.add_sample(math, state.point().position());
167-
self.exp_variance_draw_bg
168-
.add_sample(math, state.point().position());
169-
self.exp_variance_grad
170-
.add_sample(math, state.point().gradient());
171-
self.exp_variance_grad_bg
172-
.add_sample(math, state.point().gradient());
173-
174-
hamiltonian.mass_matrix.update_diag_grad(
150+
self.exp_variance_draw.add_sample(math, point.position());
151+
self.exp_variance_draw_bg.add_sample(math, point.position());
152+
self.exp_variance_grad.add_sample(math, point.gradient());
153+
self.exp_variance_grad_bg.add_sample(math, point.gradient());
154+
155+
mass_matrix.update_diag_grad(
175156
math,
176-
state.point().gradient(),
157+
point.gradient(),
177158
1f64,
178159
(INIT_LOWER_LIMIT, INIT_UPPER_LIMIT),
179160
);
180161
Ok(())
181162
}
182163

183-
fn adapt<R: Rng + ?Sized>(
184-
&mut self,
185-
_math: &mut M,
186-
_options: &mut NutsOptions,
187-
_potential: &mut Self::Hamiltonian,
188-
_draw: u64,
189-
_collector: &Self::Collector,
190-
_state: &State<M, EuclideanPoint<M>>,
191-
_rng: &mut R,
192-
) -> Result<(), NutsError> {
193-
// Must be controlled from a different meta strategy
194-
Ok(())
195-
}
196-
197164
fn new_collector(&self, math: &mut M) -> Self::Collector {
198165
DrawGradCollector::new(math)
199166
}
167+
}
200168

201-
fn is_tuning(&self) -> bool {
202-
unreachable!()
203-
}
169+
pub type Stats = ();
170+
pub type StatsBuilder = ();
171+
172+
impl<M: Math> SamplerStats<M> for Strategy<M> {
173+
type Builder = Stats;
174+
type Stats = StatsBuilder;
175+
176+
fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder {}
177+
178+
fn current_stats(&self, _math: &mut M) -> Self::Stats {}
204179
}

src/nuts.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,6 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
235235
LeapfrogResult::Ok(end) => end,
236236
};
237237

238-
// TODO sign?
239238
let log_size = -end.point().energy_error();
240239
Ok(Ok(NutsTree {
241240
right: end.clone(),

src/sampler.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,8 @@ impl Settings for TransformedNutsSettings {
315315
&self,
316316
stats: &<Self::Chain<M> as SamplerStats<M>>::Stats,
317317
) -> SampleStats {
318-
// TODO
319-
let step_size = 0.;
320-
let num_steps = 0;
318+
let step_size = stats.potential_stats.step_size;
319+
let num_steps = stats.strategy_stats.step_size.n_steps;
321320
SampleStats {
322321
chain: stats.chain,
323322
draw: stats.draw,
@@ -1065,6 +1064,7 @@ pub mod test_logps {
10651064
_rng: &mut R,
10661065
_untransformed_position: &[f64],
10671066
_untransfogmed_gradient: &[f64],
1067+
_chain: u64,
10681068
) -> std::result::Result<Self::TransformParams, Self::LogpError> {
10691069
unimplemented!()
10701070
}

0 commit comments

Comments
 (0)