Skip to content

Commit 68c08ba

Browse files
committed
fix: correctly specify dims of some sample stats
1 parent d859b82 commit 68c08ba

File tree

5 files changed

+32
-13
lines changed

5 files changed

+32
-13
lines changed

src/adapt_strategy.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ pub struct GlobalStrategyStats<P: HasDims, S: Storable<P>, M: Storable<P>> {
210210
pub step_size: S,
211211
#[storable(flatten)]
212212
pub mass_matrix: M,
213+
pub tuning: bool,
213214
#[storable(ignore)]
214215
_phantom: std::marker::PhantomData<fn() -> P>,
215216
}
@@ -243,6 +244,7 @@ where
243244
self.step_size.extract_stats(math, ())
244245
},
245246
mass_matrix: self.mass_matrix.extract_stats(math, opt.mass_matrix),
247+
tuning: self.tuning,
246248
_phantom: PhantomData,
247249
}
248250
}

src/chain.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ pub struct NutsStats<P: HasDims, H: Storable<P>, A: Storable<P>, D: Storable<P>>
217217
pub draw: u64,
218218
pub energy_error: f64,
219219
#[storable(dims("unconstrained_parameter"))]
220-
pub unconstrained: Option<Vec<f64>>,
220+
pub unconstrained_draw: Option<Vec<f64>>,
221221
#[storable(dims("unconstrained_parameter"))]
222222
pub gradient: Option<Vec<f64>>,
223223
#[storable(flatten)]
@@ -289,7 +289,7 @@ impl<M: Math, R: rand::Rng, A: AdaptStrategy<M>> SamplerStats<M> for NutsChain<M
289289
chain: self.chain,
290290
draw: self.draw_count,
291291
energy_error: point.energy_error(),
292-
unconstrained: Some(math.box_array(point.position()).into_vec()),
292+
unconstrained_draw: Some(math.box_array(point.position()).into_vec()),
293293
gradient: Some(math.box_array(point.gradient()).into_vec()),
294294
hamiltonian: hamiltonian_stats,
295295
adapt: adapt_stats,

src/mass_matrix/low_rank.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::collections::VecDeque;
2+
use std::iter::repeat;
23

34
use faer::{Col, ColRef, Mat, MatRef, Scale};
45
use itertools::Itertools;
@@ -127,15 +128,17 @@ impl Default for LowRankSettings {
127128
Self {
128129
store_mass_matrix: false,
129130
gamma: 1e-5,
130-
eigval_cutoff: 100f64,
131+
eigval_cutoff: 2f64,
131132
}
132133
}
133134
}
134135

135136
#[derive(Debug, Storable)]
136137
pub struct MatrixStats {
137-
pub eigvals: Option<Vec<f64>>,
138-
pub stds: Option<Vec<f64>>,
138+
#[storable(dims("unconstrained_parameter"))]
139+
pub mass_matrix_eigvals: Option<Vec<f64>>,
140+
#[storable(dims("unconstrained_parameter"))]
141+
pub mass_matrix_stds: Option<Vec<f64>>,
139142
pub num_eigenvalues: u64,
140143
}
141144

@@ -145,14 +148,18 @@ impl<M: Math> SamplerStats<M> for LowRankMassMatrix<M> {
145148

146149
fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
147150
if self.settings.store_mass_matrix {
151+
let stds = Some(math.box_array(&self.stds));
148152
let eigvals = self
149153
.inner
150154
.as_ref()
151155
.map(|inner| math.eigs_as_array(&inner.vals));
152-
let stds = Some(math.box_array(&self.stds));
156+
let mut eigvals = eigvals.map(|x| x.into_vec());
157+
if let Some(ref mut eigvals) = eigvals {
158+
eigvals.extend(repeat(f64::NAN).take(stds.as_ref().unwrap().len() - eigvals.len()));
159+
}
153160
MatrixStats {
154-
eigvals: eigvals.map(|x| x.into_vec()),
155-
stds: stds.map(|x| x.into_vec()),
161+
mass_matrix_eigvals: eigvals,
162+
mass_matrix_stds: stds.map(|x| x.into_vec()),
156163
num_eigenvalues: self
157164
.inner
158165
.as_ref()
@@ -161,9 +168,13 @@ impl<M: Math> SamplerStats<M> for LowRankMassMatrix<M> {
161168
}
162169
} else {
163170
MatrixStats {
164-
eigvals: None,
165-
stds: None,
166-
num_eigenvalues: 0,
171+
mass_matrix_eigvals: None,
172+
mass_matrix_stds: None,
173+
num_eigenvalues: self
174+
.inner
175+
.as_ref()
176+
.map(|inner| inner.num_eigenvalues)
177+
.unwrap_or(0),
167178
}
168179
}
169180
}

src/transform_adapt_strategy.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,18 @@ pub struct TransformAdaptation {
4343
}
4444

4545
#[derive(Debug, Storable)]
46-
pub struct Stats {}
46+
pub struct Stats {
47+
tuning: bool,
48+
}
4749

4850
impl<M: Math> SamplerStats<M> for TransformAdaptation {
4951
type Stats = Stats;
5052
type StatsOptions = ();
5153

5254
fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
53-
Stats {}
55+
Stats {
56+
tuning: self.tuning,
57+
}
5458
}
5559
}
5660

src/transformed_hamiltonian.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ pub struct TransformedPoint<M: Math> {
2626
#[derive(Debug, Storable)]
2727
pub struct PointStats {
2828
pub fisher_distance: f64,
29+
#[storable(dims("unconstrained_parameter"))]
2930
pub transformed_position: Option<Vec<f64>>,
31+
#[storable(dims("unconstrained_parameter"))]
3032
pub transformed_gradient: Option<Vec<f64>>,
3133
pub transformation_index: i64,
3234
}

0 commit comments

Comments
 (0)