Skip to content

Commit c4698fb

Browse files
committed
feature: add csv file storage backend
1 parent 36a0302 commit c4698fb

File tree

2 files changed

+1301
-0
lines changed

2 files changed

+1301
-0
lines changed

examples/csv_trace.rs

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
//! CSV backend example for MCMC trace storage
2+
//!
3+
//! This example demonstrates how to use the nuts-rs library with CSV storage
4+
//! for running MCMC sampling on a multivariate normal distribution. It shows:
5+
//!
6+
//! - Setting up a custom probability model
7+
//! - Configuring CSV storage for results
8+
//! - Running multiple parallel chains
9+
//! - Monitoring progress during sampling
10+
//! - Saving results in CmdStan-compatible CSV format
11+
12+
use std::{
13+
collections::HashMap,
14+
f64,
15+
time::{Duration, Instant},
16+
};
17+
18+
use anyhow::Result;
19+
use nuts_rs::{
20+
CpuLogpFunc, CpuMath, CpuMathError, CsvConfig, DiagGradNutsSettings, LogpError, Model, Sampler,
21+
SamplerWaitResult, Storable,
22+
};
23+
use nuts_storable::{HasDims, Value};
24+
use rand::Rng;
25+
use thiserror::Error;
26+
27+
/// A multivariate normal distribution model
28+
///
29+
/// This represents a probability distribution with mean μ and precision matrix P,
30+
/// where the log probability is: logp(x) = -0.5 * (x - μ)^T * P * (x - μ)
31+
#[derive(Clone, Debug)]
32+
struct MultivariateNormal {
33+
mean: Vec<f64>,
34+
precision: Vec<Vec<f64>>, // Inverse of covariance matrix
35+
}
36+
37+
impl MultivariateNormal {
38+
fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
39+
Self { mean, precision }
40+
}
41+
}
42+
43+
/// Custom error type for log probability calculations
44+
///
45+
/// MCMC samplers need to distinguish between recoverable errors (like numerical
46+
/// issues that can be handled by rejecting the proposal) and non-recoverable
47+
/// errors (like programming bugs that should stop sampling).
48+
#[allow(dead_code)]
49+
#[derive(Debug, Error)]
50+
enum MyLogpError {
51+
#[error("Recoverable error in logp calculation: {0}")]
52+
Recoverable(String),
53+
#[error("Non-recoverable error in logp calculation: {0}")]
54+
NonRecoverable(String),
55+
}
56+
57+
impl LogpError for MyLogpError {
58+
fn is_recoverable(&self) -> bool {
59+
matches!(self, MyLogpError::Recoverable(_))
60+
}
61+
}
62+
63+
/// Implementation of the log probability function for multivariate normal
64+
///
65+
/// This struct contains the model parameters and implements the mathematical
66+
/// operations needed for MCMC sampling: computing log probability and gradients.
67+
#[derive(Clone)]
68+
struct MvnLogp {
69+
model: MultivariateNormal,
70+
}
71+
72+
impl HasDims for MvnLogp {
73+
/// Define dimension names and sizes for storage
74+
///
75+
/// This tells the storage system what array dimensions to expect.
76+
/// These dimensions will be used to structure the output data.
77+
fn dim_sizes(&self) -> HashMap<String, u64> {
78+
HashMap::from([
79+
// Dimension for the parameter vector (for CSV storage)
80+
("param".to_string(), self.model.mean.len() as u64),
81+
])
82+
}
83+
84+
fn coords(&self) -> HashMap<String, nuts_storable::Value> {
85+
HashMap::from([(
86+
"param".to_string(),
87+
Value::Strings(vec!["mu1".to_string(), "mu2".to_string()]),
88+
)])
89+
}
90+
}
91+
92+
/// Additional quantities computed from each sample
93+
///
94+
/// The `Storable` derive macro automatically generates code to store this
95+
/// struct in the trace. The `dims` attribute specifies which dimension
96+
/// each field should use.
97+
#[derive(Storable)]
98+
struct ExpandedDraw {
99+
/// Store the parameter values as a vector with dimension "param"
100+
#[storable(dims("param"))]
101+
parameters: Vec<f64>,
102+
}
103+
104+
impl CpuLogpFunc for MvnLogp {
105+
type LogpError = MyLogpError;
106+
type FlowParameters = (); // No parameter transformations needed
107+
type ExpandedVector = ExpandedDraw;
108+
109+
/// Return the dimensionality of the parameter space
110+
fn dim(&self) -> usize {
111+
self.model.mean.len()
112+
}
113+
114+
/// Compute log probability and gradient
115+
///
116+
/// This is the core mathematical function that MCMC uses to explore
117+
/// the parameter space. It computes both the log probability density
118+
/// and its gradient for efficient sampling with Hamiltonian Monte Carlo.
119+
fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
120+
let n = x.len();
121+
122+
// Compute (x - mean)
123+
let mut diff = vec![0.0; n];
124+
for i in 0..n {
125+
diff[i] = x[i] - self.model.mean[i];
126+
}
127+
128+
let mut quad = 0.0;
129+
130+
// Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff
131+
for i in 0..n {
132+
// Compute i-th component of P * diff
133+
let mut pdot = 0.0;
134+
for j in 0..n {
135+
let pij = self.model.precision[i][j];
136+
pdot += pij * diff[j];
137+
quad += diff[i] * pij * diff[j];
138+
}
139+
// Gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is -(P * diff)_i
140+
grad[i] = -pdot;
141+
}
142+
143+
Ok(-0.5 * quad)
144+
}
145+
146+
/// Compute additional quantities from each sample
147+
///
148+
/// This function is called for each accepted sample to compute derived
149+
/// quantities that should be stored in the trace. These might be
150+
/// transformed parameters, predictions, or other quantities of interest.
151+
fn expand_vector<R: Rng + ?Sized>(
152+
&mut self,
153+
_rng: &mut R,
154+
array: &[f64],
155+
) -> Result<Self::ExpandedVector, CpuMathError> {
156+
// Return the parameter vector for CSV storage
157+
Ok(ExpandedDraw {
158+
parameters: array.to_vec(),
159+
})
160+
}
161+
162+
fn vector_coord(&self) -> Option<Value> {
163+
Some(Value::Strings(vec!["mu1".to_string(), "mu2".to_string()]))
164+
}
165+
}
166+
167+
/// The complete MCMC model
168+
///
169+
/// This struct implements the Model trait, which is the main interface
170+
/// that samplers use. It provides access to the mathematical operations
171+
/// and handles initialization of the sampling chains.
172+
struct MvnModel {
173+
math: CpuMath<MvnLogp>,
174+
}
175+
176+
impl Model for MvnModel {
177+
type Math<'model>
178+
= CpuMath<MvnLogp>
179+
where
180+
Self: 'model;
181+
182+
fn math(&self) -> Result<Self::Math<'_>> {
183+
Ok(self.math.clone())
184+
}
185+
186+
/// Generate random initial positions for the chain
187+
///
188+
/// Good initialization is important for MCMC efficiency. The starting
189+
/// points should be in a reasonable region of the parameter space
190+
/// where the log probability is finite.
191+
fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
192+
// Initialize each parameter randomly in the range [-2, 2]
193+
// For this simple example, this should put us in a reasonable
194+
// region around the mode of the distribution
195+
for p in position.iter_mut() {
196+
*p = rng.random_range(-2.0..2.0);
197+
}
198+
Ok(())
199+
}
200+
}
201+
202+
fn main() -> Result<()> {
203+
println!("=== Multivariate Normal MCMC with CSV Storage ===\n");
204+
205+
// Create a 2D multivariate normal distribution
206+
// This creates a distribution with mean [0, 0] and precision matrix
207+
// [[1.0, 0.5], [0.5, 1.0]], which corresponds to some correlation
208+
// between the two parameters
209+
let mean = vec![0.0, 0.0];
210+
let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
211+
let mvn = MultivariateNormal::new(mean, precision);
212+
213+
println!("Model: 2D Multivariate Normal");
214+
println!("Mean: {:?}", mvn.mean);
215+
println!("Precision matrix: {:?}\n", mvn.precision);
216+
217+
// Configure output location
218+
let output_path = "csv_output";
219+
println!("Output will be saved to: {}/\n", output_path);
220+
221+
// Sampling configuration
222+
let num_chains = 4; // Run 4 parallel chains for better convergence assessment
223+
let num_tune = 500; // Warmup samples to tune the sampler
224+
let num_draws = 500; // Post-warmup samples to keep
225+
226+
println!("Sampling configuration:");
227+
println!(" Chains: {}", num_chains);
228+
println!(" Warmup samples: {}", num_tune);
229+
println!(" Sampling draws: {}", num_draws);
230+
231+
// Configure MCMC settings
232+
// DiagGradNutsSettings provides sensible defaults for the NUTS sampler
233+
let mut settings = DiagGradNutsSettings::default();
234+
settings.num_chains = num_chains as _;
235+
settings.num_tune = num_tune;
236+
settings.num_draws = num_draws as _;
237+
settings.seed = 54; // For reproducible results
238+
239+
// Set up CSV storage
240+
// This will create one CSV file per chain in the specified directory
241+
let csv_config = CsvConfig::new(output_path)
242+
.with_precision(6) // 6 decimal places for floating point values
243+
.store_warmup(true); // Include warmup samples with negative sample IDs
244+
245+
// Create the model instance
246+
let model = MvnModel {
247+
math: CpuMath::new(MvnLogp { model: mvn }),
248+
};
249+
250+
// Start sampling
251+
println!("\nStarting MCMC sampling...\n");
252+
let start = Instant::now();
253+
254+
// Create sampler with 4 worker threads
255+
// The sampler runs asynchronously, so we can monitor progress
256+
let mut sampler = Some(Sampler::new(model, settings, csv_config, 4, None)?);
257+
258+
let mut num_progress_updates = 0;
259+
260+
// Main sampling loop with progress monitoring
261+
// This demonstrates how to monitor long-running sampling jobs
262+
while let Some(sampler_) = sampler.take() {
263+
match sampler_.wait_timeout(Duration::from_millis(50)) {
264+
// Sampling completed successfully
265+
SamplerWaitResult::Trace(_) => {
266+
println!("✓ Sampling completed in {:?}", start.elapsed());
267+
println!("✓ Traces written to CSV format in '{}'", output_path);
268+
269+
// List the output files
270+
if let Ok(entries) = std::fs::read_dir(output_path) {
271+
println!("\nOutput files:");
272+
for entry in entries.flatten() {
273+
if let Some(name) = entry.file_name().to_str() {
274+
if name.ends_with(".csv") {
275+
println!(" - {}", name);
276+
}
277+
}
278+
}
279+
}
280+
281+
// Provide instructions for analysis
282+
println!("\n=== Next Steps ===");
283+
println!("The CSV files are compatible with CmdStan format and can be read by:");
284+
println!(" - R: posterior package, bayesplot, etc.");
285+
println!(" - Python: arviz.from_cmdstanpy() or pandas.read_csv()");
286+
println!(" - Stan ecosystem tools");
287+
println!("\nExample usage in Python:");
288+
println!(" import pandas as pd");
289+
println!(" import arviz as az");
290+
println!(" # Read individual chains");
291+
println!(" chain0 = pd.read_csv('{}/chain_0.csv')", output_path);
292+
println!(" # Or use arviz to read all chains");
293+
println!(" # (Note: arviz.from_cmdstanpy might need adaptation)");
294+
println!("\nExample usage in R:");
295+
println!(" library(posterior)");
296+
println!(" draws <- read_cmdstan_csv(c(");
297+
for i in 0..num_chains {
298+
let comma = if i == num_chains - 1 { "" } else { "," };
299+
println!(" '{}/chain_{}.csv'{}", output_path, i, comma);
300+
}
301+
println!(" ))");
302+
println!(" summarise_draws(draws)");
303+
break;
304+
}
305+
306+
// Timeout - sampler is still running, show progress
307+
SamplerWaitResult::Timeout(mut sampler_) => {
308+
num_progress_updates += 1;
309+
println!("Progress update {}:", num_progress_updates);
310+
311+
// Get current progress from all chains
312+
let progress = sampler_.progress()?;
313+
for (i, chain) in progress.iter().enumerate() {
314+
let phase = if chain.tuning { "warmup" } else { "sampling" };
315+
println!(
316+
" Chain {}: {} samples ({} divergences), step size: {:.6} [{}]",
317+
i, chain.finished_draws, chain.divergences, chain.step_size, phase
318+
);
319+
}
320+
println!(); // Add blank line for readability
321+
322+
sampler = Some(sampler_);
323+
}
324+
325+
// An error occurred during sampling
326+
SamplerWaitResult::Err(err, _) => {
327+
eprintln!("✗ Sampling failed: {}", err);
328+
return Err(err);
329+
}
330+
}
331+
}
332+
333+
Ok(())
334+
}

0 commit comments

Comments
 (0)