Skip to content

Commit 9fc565b

Browse files
committed
refactor: Remove unnecessary stats structs and add some transform stats
1 parent 2a55602 commit 9fc565b

17 files changed

+933
-932
lines changed

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
name = "nuts-rs"
33
version = "0.13.0"
44
authors = [
5-
"Adrian Seyboldt <[email protected]>",
6-
"PyMC Developers <[email protected]>",
5+
"Adrian Seyboldt <[email protected]>",
6+
"PyMC Developers <[email protected]>",
77
]
88
edition = "2021"
99
license = "MIT"
@@ -22,12 +22,12 @@ rand = { version = "0.8.5", features = ["small_rng"] }
2222
rand_distr = "0.4.3"
2323
multiversion = "0.7.2"
2424
itertools = "0.13.0"
25-
thiserror = "1.0.43"
25+
thiserror = "2.0.3"
2626
arrow = { version = "53.1.0", default-features = false, features = ["ffi"] }
2727
rand_chacha = "0.3.1"
2828
anyhow = "1.0.72"
2929
faer = { version = "0.19.4", default-features = false, features = ["std"] }
30-
pulp = "0.18.21"
30+
pulp = "0.19.6"
3131
rayon = "1.10.0"
3232

3333
[dev-dependencies]

src/adapt_strategy.rs

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ use crate::{
1616
state::State,
1717
stepsize::AcceptanceRateCollector,
1818
stepsize_adapt::{
19-
DualAverageSettings, Stats as StepSizeStats, StatsBuilder as StepSizeStatsBuilder,
20-
Strategy as StepSizeStrategy,
19+
DualAverageSettings, StatsBuilder as StepSizeStatsBuilder, Strategy as StepSizeStrategy,
2120
},
2221
NutsError,
2322
};
@@ -63,20 +62,18 @@ impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
6362
}
6463

6564
impl<M: Math, A: MassMatrixAdaptStrategy<M>> SamplerStats<M> for GlobalStrategy<M, A> {
66-
type Stats = CombinedStats<StepSizeStats, A::Stats>;
67-
type Builder = CombinedStatsBuilder<StepSizeStatsBuilder, A::Builder>;
65+
type Builder = GlobalStrategyBuilder<A::Builder>;
66+
type StatOptions = <A as SamplerStats<M>>::StatOptions;
6867

69-
fn current_stats(&self, math: &mut M) -> Self::Stats {
70-
CombinedStats {
71-
stats1: self.step_size.current_stats(math),
72-
stats2: self.mass_matrix.current_stats(math),
73-
}
74-
}
75-
76-
fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder {
77-
CombinedStatsBuilder {
78-
stats1: SamplerStats::<M>::new_builder(&self.step_size, settings, dim),
79-
stats2: self.mass_matrix.new_builder(settings, dim),
68+
fn new_builder(
69+
&self,
70+
options: Self::StatOptions,
71+
settings: &impl Settings,
72+
dim: usize,
73+
) -> Self::Builder {
74+
GlobalStrategyBuilder {
75+
step_size: SamplerStats::<M>::new_builder(&self.step_size, (), settings, dim),
76+
mass_matrix: self.mass_matrix.new_builder(options, settings, dim),
8077
}
8178
}
8279
}
@@ -218,33 +215,37 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
218215
fn is_tuning(&self) -> bool {
219216
self.tuning
220217
}
221-
}
222218

223-
#[derive(Debug, Clone)]
224-
pub struct CombinedStats<D1, D2> {
225-
pub stats1: D1,
226-
pub stats2: D2,
219+
fn last_num_steps(&self) -> u64 {
220+
self.step_size.last_n_steps
221+
}
227222
}
228223

229-
#[derive(Clone)]
230-
pub struct CombinedStatsBuilder<B1, B2> {
231-
pub stats1: B1,
232-
pub stats2: B2,
224+
pub struct GlobalStrategyBuilder<B> {
225+
pub step_size: StepSizeStatsBuilder,
226+
pub mass_matrix: B,
233227
}
234228

235-
impl<S1, S2, B1, B2> StatTraceBuilder<CombinedStats<S1, S2>> for CombinedStatsBuilder<B1, B2>
229+
impl<M: Math, A> StatTraceBuilder<M, GlobalStrategy<M, A>> for GlobalStrategyBuilder<A::Builder>
236230
where
237-
B1: StatTraceBuilder<S1>,
238-
B2: StatTraceBuilder<S2>,
231+
A: MassMatrixAdaptStrategy<M>,
239232
{
240-
fn append_value(&mut self, value: CombinedStats<S1, S2>) {
241-
self.stats1.append_value(value.stats1);
242-
self.stats2.append_value(value.stats2);
233+
fn append_value(&mut self, math: Option<&mut M>, value: &GlobalStrategy<M, A>) {
234+
let math = math.expect("Smapler stats need math");
235+
self.step_size.append_value(Some(math), &value.step_size);
236+
self.mass_matrix
237+
.append_value(Some(math), &value.mass_matrix);
243238
}
244239

245240
fn finalize(self) -> Option<StructArray> {
246-
let Self { stats1, stats2 } = self;
247-
match (stats1.finalize(), stats2.finalize()) {
241+
let Self {
242+
step_size,
243+
mass_matrix,
244+
} = self;
245+
match (
246+
StatTraceBuilder::<M, _>::finalize(step_size),
247+
mass_matrix.finalize(),
248+
) {
248249
(None, None) => None,
249250
(Some(stats1), None) => Some(stats1),
250251
(None, Some(stats2)) => Some(stats2),
@@ -266,8 +267,14 @@ where
266267
}
267268

268269
fn inspect(&self) -> Option<StructArray> {
269-
let Self { stats1, stats2 } = self;
270-
match (stats1.inspect(), stats2.inspect()) {
270+
let Self {
271+
step_size,
272+
mass_matrix,
273+
} = self;
274+
match (
275+
StatTraceBuilder::<M, _>::inspect(step_size),
276+
mass_matrix.inspect(),
277+
) {
271278
(None, None) => None,
272279
(Some(stats1), None) => Some(stats1),
273280
(None, Some(stats2)) => Some(stats2),
@@ -374,6 +381,7 @@ pub mod test_logps {
374381

375382
#[derive(Error, Debug)]
376383
pub enum NormalLogpError {}
384+
377385
impl LogpError for NormalLogpError {
378386
fn is_recoverable(&self) -> bool {
379387
false
@@ -438,6 +446,7 @@ pub mod test_logps {
438446
_rng: &mut R,
439447
_untransformed_positions: impl Iterator<Item = &'a [f64]>,
440448
_untransformed_gradients: impl Iterator<Item = &'a [f64]>,
449+
_untransformed_logp: impl Iterator<Item = &'a f64>,
441450
_params: &'a mut Self::TransformParams,
442451
) -> Result<(), Self::LogpError> {
443452
unimplemented!()

0 commit comments

Comments
 (0)