Skip to content

Commit 069d28d — "feat: Add low-rank modified mass matrix adaptation"
Browse files · committed
1 parent (179af4b) · commit 069d28d

File tree

7 files changed: +589, −146 lines changed

7 files changed: +589, −146 lines changed

src/adapt_strategy.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,10 @@ mod test {
381381
use super::*;
382382
use crate::{
383383
cpu_math::CpuMath,
384+
mass_matrix::DiagMassMatrix,
384385
nuts::{AdaptStrategy, Chain, NutsChain, NutsOptions},
386+
potential::EuclideanPotential,
387+
DiagAdaptExpSettings,
385388
};
386389

387390
#[test]

src/cpu_math.rs

Lines changed: 99 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use std::{error::Error, fmt::Debug};
1+
use std::{error::Error, fmt::Debug, mem::replace};
22

33
use faer::{Col, Mat};
4-
use itertools::izip;
4+
use itertools::{izip, Itertools};
55
use thiserror::Error;
66

77
use crate::{
@@ -33,16 +33,35 @@ pub enum CpuMathError {
3333
impl<F: CpuLogpFunc> Math for CpuMath<F> {
3434
type Vector = Col<f64>;
3535
type EigVectors = Mat<f64>;
36-
type EigValues = Mat<f64>;
36+
type EigValues = Col<f64>;
3737
type LogpErr = F::LogpError;
3838
type Err = CpuMathError;
3939

4040
fn new_array(&self) -> Self::Vector {
4141
Col::zeros(self.dim())
4242
}
4343

44-
fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
45-
self.logp_func.logp(position, gradient)
44+
fn new_eig_vectors<'a>(
45+
&'a mut self,
46+
vals: impl ExactSizeIterator<Item = &'a [f64]>,
47+
) -> Self::EigVectors {
48+
let ndim = self.dim();
49+
let nvecs = vals.len();
50+
51+
let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
52+
vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
53+
col.try_as_slice_mut()
54+
.expect("Array is not contiguous")
55+
.copy_from_slice(vals)
56+
});
57+
58+
vectors
59+
}
60+
61+
fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
62+
let mut values: Col<f64> = Col::zeros(vals.len());
63+
values.as_slice_mut().copy_from_slice(vals);
64+
values
4665
}
4766

4867
fn logp_array(
@@ -54,6 +73,10 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
5473
.logp(position.as_slice(), gradient.as_slice_mut())
5574
}
5675

76+
fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
77+
self.logp_func.logp(position, gradient)
78+
}
79+
5780
fn dim(&self) -> usize {
5881
self.logp_func.dim()
5982
}
@@ -136,6 +159,22 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
136159
multiply(array1.as_slice(), array2.as_slice(), dest.as_slice_mut())
137160
}
138161

162+
fn array_mult_eigs(
163+
&mut self,
164+
stds: &Self::Vector,
165+
rhs: &Self::Vector,
166+
dest: &mut Self::Vector,
167+
vecs: &Self::EigVectors,
168+
vals: &Self::EigValues,
169+
) {
170+
let rhs = stds.column_vector_as_diagonal() * rhs;
171+
let trafo = vecs.transpose() * (&rhs);
172+
let inner_prod = vecs * (vals.column_vector_as_diagonal() * (&trafo) - (&trafo)) + rhs;
173+
let scaled = stds.column_vector_as_diagonal() * inner_prod;
174+
175+
let _ = replace(dest, scaled);
176+
}
177+
139178
fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
140179
vector_dot(array1.as_slice(), array2.as_slice())
141180
}
@@ -156,6 +195,28 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
156195
});
157196
}
158197

198+
fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
199+
&mut self,
200+
rng: &mut R,
201+
dest: &mut Self::Vector,
202+
scale: &Self::Vector,
203+
vals: &Self::EigValues,
204+
vecs: &Self::EigVectors,
205+
) {
206+
let mut draw: Col<f64> = Col::zeros(self.dim());
207+
let dist = rand_distr::StandardNormal;
208+
draw.as_slice_mut().iter_mut().for_each(|p| {
209+
*p = rng.sample(dist);
210+
});
211+
212+
let trafo = vecs.transpose() * (&draw);
213+
let inner_prod = vecs * (vals.column_vector_as_diagonal() * (&trafo) - (&trafo)) + draw;
214+
215+
let scaled = scale.column_vector_as_diagonal() * inner_prod;
216+
217+
let _ = replace(dest, scaled);
218+
}
219+
159220
fn array_update_variance(
160221
&mut self,
161222
mean: &mut Self::Vector,
@@ -177,6 +238,37 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
177238
})
178239
}
179240

241+
fn array_update_var_inv_std_draw(
242+
&mut self,
243+
variance_out: &mut Self::Vector,
244+
inv_std: &mut Self::Vector,
245+
draw_var: &Self::Vector,
246+
scale: f64,
247+
fill_invalid: Option<f64>,
248+
clamp: (f64, f64),
249+
) {
250+
self.arch.dispatch(|| {
251+
izip!(
252+
variance_out.as_slice_mut().iter_mut(),
253+
inv_std.as_slice_mut().iter_mut(),
254+
draw_var.as_slice().iter(),
255+
)
256+
.for_each(|(var_out, inv_std_out, &draw_var)| {
257+
let draw_var = draw_var * scale;
258+
if (!draw_var.is_finite()) | (draw_var == 0f64) {
259+
if let Some(fill_val) = fill_invalid {
260+
*var_out = fill_val;
261+
*inv_std_out = fill_val.recip().sqrt();
262+
}
263+
} else {
264+
let val = draw_var.clamp(clamp.0, clamp.1);
265+
*var_out = val;
266+
*inv_std_out = val.recip().sqrt();
267+
}
268+
});
269+
});
270+
}
271+
180272
fn array_update_var_inv_std_draw_grad(
181273
&mut self,
182274
variance_out: &mut Self::Vector,
@@ -232,56 +324,8 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
232324
});
233325
}
234326

235-
fn array_update_var_inv_std_draw(
236-
&mut self,
237-
variance_out: &mut Self::Vector,
238-
inv_std: &mut Self::Vector,
239-
draw_var: &Self::Vector,
240-
scale: f64,
241-
fill_invalid: Option<f64>,
242-
clamp: (f64, f64),
243-
) {
244-
self.arch.dispatch(|| {
245-
izip!(
246-
variance_out.as_slice_mut().iter_mut(),
247-
inv_std.as_slice_mut().iter_mut(),
248-
draw_var.as_slice().iter(),
249-
)
250-
.for_each(|(var_out, inv_std_out, &draw_var)| {
251-
let draw_var = draw_var * scale;
252-
if (!draw_var.is_finite()) | (draw_var == 0f64) {
253-
if let Some(fill_val) = fill_invalid {
254-
*var_out = fill_val;
255-
*inv_std_out = fill_val.recip().sqrt();
256-
}
257-
} else {
258-
let val = draw_var.clamp(clamp.0, clamp.1);
259-
*var_out = val;
260-
*inv_std_out = val.recip().sqrt();
261-
}
262-
});
263-
});
264-
}
265-
266-
fn new_eig_vectors<'a>(
267-
&'a mut self,
268-
vals: impl ExactSizeIterator<Item = &'a [f64]>,
269-
) -> Self::EigVectors {
270-
todo!()
271-
}
272-
273-
fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
274-
todo!()
275-
}
276-
277-
fn scaled_eigval_matmul(
278-
&mut self,
279-
scale: &Self::Vector,
280-
vals: &Self::EigValues,
281-
vecs: &Self::EigVectors,
282-
out: &mut Self::Vector,
283-
) {
284-
todo!()
327+
fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
328+
source.as_slice().to_vec().into()
285329
}
286330
}
287331

src/lib.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,11 @@ pub use cpu_math::{CpuLogpFunc, CpuMath};
102102
pub use math_base::Math;
103103
pub use nuts::{Chain, DivergenceInfo, LogpError, NutsError, SampleStats};
104104
pub use sampler::{
105-
sample_sequentially, ChainOutput, ChainProgress, DiagGradNutsSettings, DrawStorage, Model,
106-
ProgressCallback, Sampler, SamplerWaitResult, Settings, Trace,
105+
sample_sequentially, ChainOutput, ChainProgress, DiagGradNutsSettings, DrawStorage,
106+
LowRankNutsSettings, Model, NutsSettings, ProgressCallback, Sampler, SamplerWaitResult,
107+
Settings, Trace,
107108
};
108109

110+
pub use low_rank_mass_matrix::LowRankSettings;
109111
pub use mass_matrix_adapt::DiagAdaptExpSettings;
110112
pub use stepsize_adapt::DualAverageSettings;

0 commit comments

Comments (0)