Skip to content

Commit 46203f6

Browse files
committed
feat: expose sample data via ProgressCallback
Extend ChainProgress with an optional latest_sample field instead of adding a separate SampleCallback. This consolidates the callback API while maintaining access to per-sample data.
1 parent 265ea63 commit 46203f6

File tree

5 files changed

+553
-24
lines changed

5 files changed

+553
-24
lines changed

examples/sample_callback.rs

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
//! Example demonstrating sample-level data access via ProgressCallback
2+
//!
3+
//! This example shows how to access per-sample data through the existing
4+
//! ProgressCallback using the `latest_sample` field in ChainProgress.
5+
6+
use std::{
7+
f64,
8+
sync::{Arc, Mutex},
9+
time::Duration,
10+
};
11+
12+
use anyhow::Result;
13+
use nuts_rs::{
14+
CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, HashMapConfig, LogpError, Model,
15+
ProgressCallback, Sampler,
16+
};
17+
use nuts_storable::HasDims;
18+
use rand::{Rng, RngExt};
19+
use thiserror::Error;
20+
21+
/// A simple multivariate normal distribution, parameterised by its mean
/// vector and a dense precision (inverse covariance) matrix.
#[derive(Clone, Debug)]
struct MultivariateNormal {
    mean: Vec<f64>,
    precision: Vec<Vec<f64>>,
}

impl MultivariateNormal {
    /// Build a distribution from a mean vector and a precision matrix.
    fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
        MultivariateNormal { mean, precision }
    }
}
33+
34+
// Custom LogpError implementation
//
// Distinguishes failures of the logp evaluation that the sampler can
// recover from (see `is_recoverable`) from fatal ones. The `#[error(...)]`
// attributes give `thiserror` the `Display` text for each variant.
#[allow(dead_code)]
#[derive(Debug, Error)]
enum MyLogpError {
    /// Reported as recoverable by `is_recoverable`.
    #[error("Recoverable error in logp calculation: {0}")]
    Recoverable(String),
    /// Reported as non-recoverable by `is_recoverable`.
    #[error("Non-recoverable error in logp calculation: {0}")]
    NonRecoverable(String),
}
43+
44+
impl LogpError for MyLogpError {
45+
fn is_recoverable(&self) -> bool {
46+
matches!(self, MyLogpError::Recoverable(_))
47+
}
48+
}
49+
50+
// Implementation of the model's logp function
#[derive(Clone)]
struct MvnLogp {
    // The target distribution whose log-density and gradient are evaluated
    // in the `CpuLogpFunc` impl.
    model: MultivariateNormal,
}
55+
56+
impl HasDims for MvnLogp {
    /// Report named dimension sizes; both registered dims equal the
    /// length of the mean vector (the parameter count).
    fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
        let n = self.model.mean.len() as u64;
        [("unconstrained_parameter", n), ("dim", n)]
            .into_iter()
            .map(|(name, size)| (name.to_string(), size))
            .collect()
    }
}
67+
68+
impl CpuLogpFunc for MvnLogp {
69+
type LogpError = MyLogpError;
70+
type FlowParameters = ();
71+
type ExpandedVector = Vec<f64>;
72+
73+
fn dim(&self) -> usize {
74+
self.model.mean.len()
75+
}
76+
77+
fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
78+
let n = x.len();
79+
// Compute (x - mean)
80+
let mut diff = vec![0.0; n];
81+
for i in 0..n {
82+
diff[i] = x[i] - self.model.mean[i];
83+
}
84+
85+
let mut quad = 0.0;
86+
// Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff
87+
for i in 0..n {
88+
// Compute i-th component of P * diff
89+
let mut pdot = 0.0;
90+
for j in 0..n {
91+
let pij = self.model.precision[i][j];
92+
pdot += pij * diff[j];
93+
quad += diff[i] * pij * diff[j];
94+
}
95+
// gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is - (P * diff)_i
96+
grad[i] = -pdot;
97+
}
98+
99+
Ok(-0.5 * quad)
100+
}
101+
102+
fn expand_vector<R: Rng + ?Sized>(
103+
&mut self,
104+
_rng: &mut R,
105+
array: &[f64],
106+
) -> Result<Self::ExpandedVector, CpuMathError> {
107+
// Simply return the parameter values
108+
Ok(array.to_vec())
109+
}
110+
}
111+
112+
struct MvnModel {
    // CPU-backed math context wrapping the logp implementation; it is
    // cloned for each chain in `Model::math`.
    math: CpuMath<MvnLogp>,
}
115+
116+
/// Implementation of Model for the HashMap backend
117+
impl Model for MvnModel {
118+
type Math<'model>
119+
= CpuMath<MvnLogp>
120+
where
121+
Self: 'model;
122+
123+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
124+
Ok(self.math.clone())
125+
}
126+
127+
/// Generate random initial positions for the chain
128+
fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
129+
// Initialize position randomly in [-2, 2]
130+
for p in position.iter_mut() {
131+
*p = rng.random_range(-2.0..2.0);
132+
}
133+
Ok(())
134+
}
135+
}
136+
137+
fn main() -> Result<()> {
138+
println!("=== Sample-Level Data via ProgressCallback Example ===\n");
139+
println!("This example demonstrates accessing per-sample data through ProgressCallback.");
140+
println!("The callback fires periodically (rate-limited to 10ms) with chain progress,");
141+
println!("including the latest sample data for each chain.\n");
142+
143+
// Create a 2D multivariate normal distribution
144+
let mean = vec![0.0, 0.0];
145+
let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
146+
let mvn = MultivariateNormal::new(mean, precision);
147+
148+
// Number of chains
149+
let num_chains = 2;
150+
151+
// Configure number of draws
152+
let num_tune = 50;
153+
let num_draws = 100;
154+
155+
// Configure MCMC settings
156+
let mut settings = DiagGradNutsSettings::default();
157+
settings.num_chains = num_chains as _;
158+
settings.num_tune = num_tune;
159+
settings.num_draws = num_draws as _;
160+
settings.seed = 42;
161+
162+
let model = MvnModel {
163+
math: CpuMath::new(MvnLogp { model: mvn }),
164+
};
165+
166+
// Track callback invocations for demonstration
167+
let callback_count = Arc::new(Mutex::new(0));
168+
let callback_count_clone = callback_count.clone();
169+
170+
let divergence_count = Arc::new(Mutex::new(0));
171+
let divergence_count_clone = divergence_count.clone();
172+
173+
// Create progress callback that accesses latest sample data
174+
let progress_callback = ProgressCallback {
175+
callback: Box::new(move |elapsed, chains| {
176+
let mut count = callback_count_clone.lock().unwrap();
177+
*count += 1;
178+
179+
// Print progress information periodically
180+
if *count <= 10 {
181+
println!(
182+
"📊 Progress callback #{}: Elapsed: {:.1}s, {} chains",
183+
count,
184+
elapsed.as_secs_f64(),
185+
chains.len()
186+
);
187+
188+
for chain_progress in chains.iter() {
189+
// Access the latest sample data if available
190+
if let Some(sample_data) = &chain_progress.latest_sample {
191+
println!(
192+
" Chain {}: Draw {}/{}, Energy: {:.3}, Diverging: {}, Tree depth: {}",
193+
sample_data.chain_id,
194+
chain_progress.finished_draws,
195+
chain_progress.total_draws,
196+
sample_data.energy,
197+
sample_data.diverging,
198+
sample_data.tree_depth
199+
);
200+
println!(
201+
" Position: [{:.4}, {:.4}]",
202+
sample_data.position[0], sample_data.position[1]
203+
);
204+
println!(
205+
" Step size: {:.6}, Tuning: {}",
206+
sample_data.step_size, sample_data.is_tuning
207+
);
208+
209+
// Track divergences
210+
if sample_data.diverging {
211+
let mut div_count = divergence_count_clone.lock().unwrap();
212+
*div_count += 1;
213+
}
214+
}
215+
}
216+
println!();
217+
} else if *count == 11 {
218+
println!(" ... (suppressing further callback output) ...\n");
219+
}
220+
}),
221+
rate: Duration::from_millis(10), // Rate limit: at most one callback per 10ms
222+
};
223+
224+
// Create a new sampler with the progress callback
225+
let trace_config = HashMapConfig::new();
226+
let mut sampler = Sampler::new(
227+
model,
228+
settings,
229+
trace_config,
230+
4, // num_cores
231+
Some(progress_callback), // progress callback with sample data access
232+
)?;
233+
234+
println!("Starting sampling with progress callback...\n");
235+
236+
// Wait for sampling to complete
237+
let traces = loop {
238+
match sampler.wait_timeout(std::time::Duration::from_millis(100)) {
239+
nuts_rs::SamplerWaitResult::Trace(traces) => break traces,
240+
nuts_rs::SamplerWaitResult::Timeout(s) => sampler = s,
241+
nuts_rs::SamplerWaitResult::Err(e, _) => return Err(e),
242+
}
243+
};
244+
245+
println!("\n=== Sampling Complete ===");
246+
println!(
247+
"Total callback invocations: {}",
248+
*callback_count.lock().unwrap()
249+
);
250+
println!(
251+
"Divergences detected via callback: {}",
252+
*divergence_count.lock().unwrap()
253+
);
254+
println!("Number of chains: {}", traces.len());
255+
256+
// Show some basic statistics from the traces
257+
for (chain_idx, chain_result) in traces.iter().enumerate() {
258+
println!("\nChain {}:", chain_idx);
259+
260+
// Count divergences from stats
261+
if let Some(nuts_rs::HashMapValue::Bool(divergences)) = chain_result.stats.get("diverging")
262+
{
263+
let div_count = divergences.iter().filter(|&&d| d).count();
264+
println!(" Divergences in trace: {}", div_count);
265+
}
266+
267+
// Calculate mean position
268+
if let Some(nuts_rs::HashMapValue::F64(positions)) = chain_result.draws.get("theta") {
269+
if positions.len() >= 2 {
270+
let x_mean: f64 =
271+
positions.iter().step_by(2).sum::<f64>() / (positions.len() / 2) as f64;
272+
let y_mean: f64 =
273+
positions.iter().skip(1).step_by(2).sum::<f64>() / (positions.len() / 2) as f64;
274+
println!(" Mean position: [{:.4}, {:.4}]", x_mean, y_mean);
275+
}
276+
}
277+
}
278+
279+
println!("\n✅ Example completed successfully!");
280+
println!("\nKey features demonstrated:");
281+
println!(" - ProgressCallback provides both chain progress and latest sample data");
282+
println!(" - Time-based rate limiting (10ms) prevents excessive overhead");
283+
println!(" - latest_sample includes rich data (position, energy, divergence, etc.)");
284+
println!(" - Works seamlessly with multi-chain sampling");
285+
println!(" - Single callback mechanism for all monitoring needs");
286+
287+
Ok(())
288+
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ pub use model::Model;
129129
pub use nuts::NutsError;
130130
pub use sampler::{
131131
ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, NutsSettings, Progress,
132-
ProgressCallback, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings,
132+
ProgressCallback, SampleData, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings,
133133
sample_sequentially,
134134
};
135135
pub use sampler_stats::SamplerStats;

0 commit comments

Comments
 (0)