Skip to content

Commit 2a55602

Browse files
committed
feat: Add sampler stats for points
1 parent c416c42 commit 2a55602

File tree

9 files changed

+180
-31
lines changed

9 files changed

+180
-31
lines changed

src/chain.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@ where
4444
draw_count: u64,
4545
strategy: A,
4646
math: M,
47-
stats: Option<NutsSampleStats<<A::Hamiltonian as SamplerStats<M>>::Stats, A::Stats>>,
47+
stats: Option<
48+
NutsSampleStats<
49+
<<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::Stats,
50+
<A::Hamiltonian as SamplerStats<M>>::Stats,
51+
A::Stats,
52+
>,
53+
>,
4854
}
4955

5056
impl<M, R, A> NutsChain<M, R, A>
@@ -117,17 +123,22 @@ where
117123
A: AdaptStrategy<M>,
118124
{
119125
type Builder = NutsStatsBuilder<
126+
<<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::Builder,
120127
<A::Hamiltonian as SamplerStats<M>>::Builder,
121128
<A as SamplerStats<M>>::Builder,
122129
>;
123-
type Stats =
124-
NutsSampleStats<<A::Hamiltonian as SamplerStats<M>>::Stats, <A as SamplerStats<M>>::Stats>;
130+
type Stats = NutsSampleStats<
131+
<<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::Stats,
132+
<A::Hamiltonian as SamplerStats<M>>::Stats,
133+
<A as SamplerStats<M>>::Stats,
134+
>;
125135

126136
fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder {
127137
NutsStatsBuilder::new_with_capacity(
128138
settings,
129139
&self.hamiltonian,
130140
&self.strategy,
141+
self.init.point(),
131142
dim,
132143
&self.options,
133144
)
@@ -182,6 +193,7 @@ where
182193
draw: self.draw_count,
183194
potential_stats: self.hamiltonian.current_stats(&mut self.math),
184195
strategy_stats: self.strategy.current_stats(&mut self.math),
196+
point_stats: state.point().current_stats(&mut self.math),
185197
gradient: if self.options.store_gradient {
186198
let mut gradient: Box<[f64]> = vec![0f64; self.math.dim()].into();
187199
state.write_gradient(&mut self.math, &mut gradient);

src/cpu_math.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
128128
)
129129
}
130130

131+
fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
132+
x.as_slice()
133+
.iter()
134+
.zip(y.as_slice())
135+
.map(|(&x, &y)| (x + y) * (x + y))
136+
.sum()
137+
}
138+
131139
fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
132140
dest.as_slice_mut().copy_from_slice(source);
133141
}

src/euclidean_hamiltonian.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,40 @@ pub struct EuclideanPoint<M: Math> {
5252
pub initial_energy: f64,
5353
}
5454

55+
#[derive(Clone, Debug)]
56+
pub struct PointStats {}
57+
58+
pub struct PointStatsBuilder {}
59+
60+
impl StatTraceBuilder<PointStats> for PointStatsBuilder {
61+
fn append_value(&mut self, value: PointStats) {
62+
let PointStats {} = value;
63+
}
64+
65+
fn finalize(self) -> Option<StructArray> {
66+
let Self {} = self;
67+
None
68+
}
69+
70+
fn inspect(&self) -> Option<StructArray> {
71+
let Self {} = self;
72+
None
73+
}
74+
}
75+
76+
impl<M: Math> SamplerStats<M> for EuclideanPoint<M> {
77+
type Stats = PointStats;
78+
type Builder = PointStatsBuilder;
79+
80+
fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder {
81+
Self::Builder {}
82+
}
83+
84+
fn current_stats(&self, _math: &mut M) -> Self::Stats {
85+
PointStats {}
86+
}
87+
}
88+
5589
impl<M: Math> EuclideanPoint<M> {
5690
fn is_turning(&self, math: &mut M, other: &Self) -> bool {
5791
let (start, end) = if self.index_in_trajectory() < other.index_in_trajectory() {

src/hamiltonian.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ pub enum LeapfrogResult<M: Math, P: Point<M>> {
4848
Err(M::LogpErr),
4949
}
5050

51-
pub trait Point<M: Math>: Sized {
51+
pub trait Point<M: Math>: Sized + SamplerStats<M> {
5252
fn position(&self) -> &M::Vector;
5353
fn gradient(&self) -> &M::Vector;
5454
fn index_in_trajectory(&self) -> i64;

src/math_base.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ pub trait Math {
6464
y: &Self::Vector,
6565
) -> (f64, f64);
6666

67+
fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64;
68+
6769
fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]);
6870
fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]);
6971
fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]>;

src/nuts.rs

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,11 @@ where
310310

311311
#[derive(Debug, Clone)]
312312
#[non_exhaustive]
313-
pub struct NutsSampleStats<HStats: Send + Debug + Clone, AdaptStats: Send + Debug + Clone> {
313+
pub struct NutsSampleStats<
314+
PointStats: Send + Debug + Clone,
315+
HStats: Send + Debug + Clone,
316+
AdaptStats: Send + Debug + Clone,
317+
> {
314318
pub depth: u64,
315319
pub maxdepth_reached: bool,
316320
pub idx_in_trajectory: i64,
@@ -324,6 +328,7 @@ pub struct NutsSampleStats<HStats: Send + Debug + Clone, AdaptStats: Send + Debu
324328
pub unconstrained: Option<Box<[f64]>>,
325329
pub potential_stats: HStats,
326330
pub strategy_stats: AdaptStats,
331+
pub point_stats: PointStats,
327332
pub tuning: bool,
328333
}
329334

@@ -338,7 +343,7 @@ pub struct SampleStats {
338343
pub num_steps: u64,
339344
}
340345

341-
pub struct NutsStatsBuilder<H, A> {
346+
pub struct NutsStatsBuilder<P, H, A> {
342347
depth: PrimitiveBuilder<UInt64Type>,
343348
maxdepth_reached: BooleanBuilder,
344349
index_in_trajectory: PrimitiveBuilder<Int64Type>,
@@ -351,6 +356,7 @@ pub struct NutsStatsBuilder<H, A> {
351356
gradient: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
352357
hamiltonian: H,
353358
adapt: A,
359+
point: P,
354360
diverging: BooleanBuilder,
355361
divergence_start: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
356362
divergence_start_grad: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
@@ -360,15 +366,17 @@ pub struct NutsStatsBuilder<H, A> {
360366
n_dim: usize,
361367
}
362368

363-
impl<HB, AB> NutsStatsBuilder<HB, AB> {
369+
impl<PB, HB, AB> NutsStatsBuilder<PB, HB, AB> {
364370
pub fn new_with_capacity<
365371
M: Math,
366-
H: Hamiltonian<M, Builder = HB>,
372+
P: Point<M, Builder = PB>,
373+
H: Hamiltonian<M, Builder = HB, Point = P>,
367374
A: AdaptStrategy<M, Builder = AB>,
368375
>(
369376
settings: &impl Settings,
370377
hamiltonian: &H,
371378
adapt: &A,
379+
point: &P,
372380
dim: usize,
373381
options: &NutsOptions,
374382
) -> Self {
@@ -430,6 +438,7 @@ impl<HB, AB> NutsStatsBuilder<HB, AB> {
430438
unconstrained,
431439
hamiltonian: hamiltonian.new_builder(settings, dim),
432440
adapt: adapt.new_builder(settings, dim),
441+
point: point.new_builder(settings, dim),
433442
diverging: BooleanBuilder::with_capacity(capacity),
434443
divergence_start: div_start,
435444
divergence_start_grad: div_start_grad,
@@ -441,14 +450,17 @@ impl<HB, AB> NutsStatsBuilder<HB, AB> {
441450
}
442451
}
443452

444-
impl<HS, AS, HB, AB> StatTraceBuilder<NutsSampleStats<HS, AS>> for NutsStatsBuilder<HB, AB>
453+
impl<PS, HS, AS, PB, HB, AB> StatTraceBuilder<NutsSampleStats<PS, HS, AS>>
454+
for NutsStatsBuilder<PB, HB, AB>
445455
where
446456
HB: StatTraceBuilder<HS>,
447457
AB: StatTraceBuilder<AS>,
458+
PB: StatTraceBuilder<PS>,
448459
HS: Clone + Send + Debug,
449460
AS: Clone + Send + Debug,
461+
PS: Clone + Send + Debug,
450462
{
451-
fn append_value(&mut self, value: NutsSampleStats<HS, AS>) {
463+
fn append_value(&mut self, value: NutsSampleStats<PS, HS, AS>) {
452464
let NutsSampleStats {
453465
depth,
454466
maxdepth_reached,
@@ -463,6 +475,7 @@ where
463475
unconstrained,
464476
potential_stats,
465477
strategy_stats,
478+
point_stats,
466479
tuning,
467480
} = value;
468481

@@ -532,6 +545,7 @@ where
532545

533546
self.hamiltonian.append_value(potential_stats);
534547
self.adapt.append_value(strategy_stats);
548+
self.point.append_value(point_stats);
535549
}
536550

537551
fn finalize(self) -> Option<StructArray> {
@@ -548,6 +562,7 @@ where
548562
gradient,
549563
hamiltonian,
550564
adapt,
565+
point,
551566
mut diverging,
552567
divergence_start,
553568
divergence_start_grad,
@@ -615,6 +630,7 @@ where
615630

616631
merge_into(hamiltonian, &mut arrays, &mut fields);
617632
merge_into(adapt, &mut arrays, &mut fields);
633+
merge_into(point, &mut arrays, &mut fields);
618634

619635
add_field(gradient, "gradient", &mut arrays, &mut fields);
620636
add_field(
@@ -667,6 +683,7 @@ where
667683
gradient,
668684
hamiltonian,
669685
adapt,
686+
point,
670687
diverging,
671688
divergence_start,
672689
divergence_start_grad,
@@ -734,6 +751,7 @@ where
734751

735752
merge_into(hamiltonian, &mut arrays, &mut fields);
736753
merge_into(adapt, &mut arrays, &mut fields);
754+
merge_into(point, &mut arrays, &mut fields);
737755

738756
add_field(gradient, "gradient", &mut arrays, &mut fields);
739757
add_field(

src/sampler.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ pub mod test_logps {
949949
}
950950
}
951951

952-
impl<'a> CpuLogpFunc for &'a NormalLogp {
952+
impl CpuLogpFunc for &NormalLogp {
953953
type LogpError = NormalLogpError;
954954
type TransformParams = ();
955955

@@ -981,7 +981,7 @@ pub mod test_logps {
981981
for (p, g) in pos.chunks_exact(4).zip(grad.chunks_exact_mut(4)) {
982982
let p = f64x4::from_slice(p);
983983
let val = mu_splat - p;
984-
logp = logp - val * val * f64x4::splat(0.5);
984+
logp = val * val * f64x4::splat(0.5);
985985
g.copy_from_slice(&val.to_array());
986986
}
987987

src/transform_adapt_strategy.rs

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -108,38 +108,56 @@ impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
108108
math: &mut M,
109109
_start: &State<M, P>,
110110
end: &State<M, P>,
111-
_divergence_info: Option<&crate::DivergenceInfo>,
111+
divergence_info: Option<&crate::DivergenceInfo>,
112112
) {
113+
if divergence_info.is_some() {
114+
return;
115+
}
116+
113117
if self.collect_orbit {
114118
let point = end.point();
115119
let energy_error = point.energy_error();
116-
if energy_error.abs() < self.max_energy_error {
117-
if !math.array_all_finite(point.position()) {
118-
return;
119-
}
120-
if !math.array_all_finite(point.gradient()) {
121-
return;
122-
}
123-
self.draws.push(math.copy_array(point.position()));
124-
self.grads.push(math.copy_array(point.gradient()));
120+
if !energy_error.is_finite() {
121+
return;
122+
}
123+
124+
if energy_error > self.max_energy_error {
125+
return;
126+
}
127+
128+
if !math.array_all_finite(point.position()) {
129+
return;
130+
}
131+
if !math.array_all_finite(point.gradient()) {
132+
return;
125133
}
134+
135+
self.draws.push(math.copy_array(point.position()));
136+
self.grads.push(math.copy_array(point.gradient()));
126137
}
127138
}
128139

129140
fn register_draw(&mut self, math: &mut M, state: &State<M, P>, _info: &SampleInfo) {
130141
if !self.collect_orbit {
131142
let point = state.point();
132143
let energy_error = point.energy_error();
133-
if energy_error.abs() < self.max_energy_error {
134-
if !math.array_all_finite(point.position()) {
135-
return;
136-
}
137-
if !math.array_all_finite(point.gradient()) {
138-
return;
139-
}
140-
self.draws.push(math.copy_array(point.position()));
141-
self.grads.push(math.copy_array(point.gradient()));
144+
if !energy_error.is_finite() {
145+
return;
146+
}
147+
148+
if energy_error > self.max_energy_error {
149+
return;
142150
}
151+
152+
if !math.array_all_finite(point.position()) {
153+
return;
154+
}
155+
if !math.array_all_finite(point.gradient()) {
156+
return;
157+
}
158+
159+
self.draws.push(math.copy_array(point.position()));
160+
self.grads.push(math.copy_array(point.gradient()));
143161
}
144162
}
145163
}

0 commit comments

Comments
 (0)