Skip to content

Commit 8893d2b

Browse files
committed
feat: implement arrow storage
1 parent 2fe6539 commit 8893d2b

File tree

5 files changed

+995
-0
lines changed

5 files changed

+995
-0
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ zarrs = { version = "0.21.0", features = [
3434
"async",
3535
], optional = true }
3636
ndarray = { version = "0.16.1", optional = true }
37+
arrow = { version = "56.2.0", optional = true }
38+
arrow-schema = { version = "56.2.0", features = [
39+
"canonical_extension_types",
40+
], optional = true }
3741
nuts-derive = { path = "./nuts-derive" }
3842
nuts-storable = { path = "./nuts-storable" }
3943
serde = { version = "1.0.219", features = ["derive"] }
@@ -57,6 +61,7 @@ tokio = { version = "1.0", features = ["rt", "rt-multi-thread"] }
5761
[features]
5862
zarr = ["dep:zarrs", "dep:tokio"]
5963
ndarray = ["dep:ndarray"]
64+
arrow = ["dep:arrow", "dep:arrow-schema"]
6065

6166
[[bench]]
6267
name = "sample"

examples/arrow_trace.rs

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
//! Arrow backend example for MCMC trace storage
2+
//!
3+
//! This example demonstrates how to use the nuts-rs library with Arrow storage
4+
//! for running MCMC sampling on a multivariate normal distribution. It shows:
5+
//!
6+
//! - Setting up a custom probability model
7+
//! - Configuring Arrow storage for results
8+
//! - Running multiple parallel chains
9+
//! - Monitoring progress during sampling
10+
//! - Accessing results in Arrow/Parquet format
11+
12+
use std::{
13+
collections::HashMap,
14+
f64,
15+
time::{Duration, Instant},
16+
};
17+
18+
use anyhow::Result;
19+
use nuts_rs::{
20+
ArrowConfig, CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model,
21+
Sampler, SamplerWaitResult, Storable,
22+
};
23+
use nuts_storable::{HasDims, Value};
24+
use rand::Rng;
25+
use thiserror::Error;
26+
27+
/// A multivariate normal distribution model
28+
///
29+
/// This represents a probability distribution with mean μ and precision matrix P,
30+
/// where the log probability is: logp(x) = -0.5 * (x - μ)^T * P * (x - μ)
31+
#[derive(Clone, Debug)]
32+
struct MultivariateNormal {
33+
mean: Vec<f64>,
34+
precision: Vec<Vec<f64>>, // Inverse of covariance matrix
35+
}
36+
37+
impl MultivariateNormal {
38+
fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
39+
Self { mean, precision }
40+
}
41+
}
42+
43+
/// Custom error type for log probability calculations
44+
///
45+
/// MCMC samplers need to distinguish between recoverable errors (like numerical
46+
/// issues that can be handled by rejecting the proposal) and non-recoverable
47+
/// errors (like programming bugs that should stop sampling).
48+
#[allow(dead_code)]
49+
#[derive(Debug, Error)]
50+
enum MyLogpError {
51+
#[error("Recoverable error in logp calculation: {0}")]
52+
Recoverable(String),
53+
#[error("Non-recoverable error in logp calculation: {0}")]
54+
NonRecoverable(String),
55+
}
56+
57+
impl LogpError for MyLogpError {
58+
fn is_recoverable(&self) -> bool {
59+
matches!(self, MyLogpError::Recoverable(_))
60+
}
61+
}
62+
63+
/// Implementation of the log probability function for multivariate normal
64+
///
65+
/// This struct contains the model parameters and implements the mathematical
66+
/// operations needed for MCMC sampling: computing log probability and gradients.
67+
#[derive(Clone)]
68+
struct MvnLogp {
69+
model: MultivariateNormal,
70+
buffer: Vec<f64>, // Temporary buffer for computations
71+
}
72+
73+
impl HasDims for MvnLogp {
74+
/// Define dimension names and sizes for storage
75+
///
76+
/// This tells the storage system what array dimensions to expect.
77+
/// These dimensions will be used to structure the output data using
78+
/// Arrow's FixedShapeTensor extension type.
79+
fn dim_sizes(&self) -> HashMap<String, u64> {
80+
HashMap::from([
81+
// Dimension for the actual parameter vector x
82+
("x".to_string(), self.model.mean.len() as u64),
83+
])
84+
}
85+
86+
fn coords(&self) -> HashMap<String, nuts_storable::Value> {
87+
HashMap::from([(
88+
"x".to_string(),
89+
Value::Strings(vec!["x1".to_string(), "x2".to_string()]),
90+
)])
91+
}
92+
}
93+
94+
/// Additional quantities computed from each sample
95+
///
96+
/// The `Storable` derive macro automatically generates code to store this
97+
/// struct in the trace. The `dims` attribute specifies which dimension
98+
/// each field should use. Multi-dimensional fields will be stored as
99+
/// FixedShapeTensor extension types in Arrow format.
100+
#[derive(Storable)]
101+
struct ExpandedDraw {
102+
/// Store the parameter values with dimension "x"
103+
#[storable(dims("x"))]
104+
prec: Vec<f64>,
105+
/// A scalar derived quantity (difference between first two parameters)
106+
diff: f64,
107+
}
108+
109+
impl CpuLogpFunc for MvnLogp {
110+
type LogpError = MyLogpError;
111+
type FlowParameters = (); // No parameter transformations needed
112+
type ExpandedVector = ExpandedDraw;
113+
114+
/// Return the dimensionality of the parameter space
115+
fn dim(&self) -> usize {
116+
self.model.mean.len()
117+
}
118+
119+
/// Compute log probability and gradient
120+
///
121+
/// This is the core mathematical function that MCMC uses to explore
122+
/// the parameter space. It computes both the log probability density
123+
/// and its gradient for efficient sampling with Hamiltonian Monte Carlo.
124+
fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
125+
let n = x.len();
126+
127+
// Compute (x - mean)
128+
let diff = &mut self.buffer;
129+
for i in 0..n {
130+
diff[i] = x[i] - self.model.mean[i];
131+
}
132+
133+
let mut quad = 0.0;
134+
135+
// Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff
136+
for i in 0..n {
137+
// Compute i-th component of P * diff
138+
let mut pdot = 0.0;
139+
for j in 0..n {
140+
let pij = self.model.precision[i][j];
141+
pdot += pij * diff[j];
142+
quad += diff[i] * pij * diff[j];
143+
}
144+
// Gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is -(P * diff)_i
145+
grad[i] = -pdot;
146+
}
147+
148+
Ok(-0.5 * quad)
149+
}
150+
151+
/// Compute additional quantities from each sample
152+
///
153+
/// This function is called for each accepted sample to compute derived
154+
/// quantities that should be stored in the trace. These might be
155+
/// transformed parameters, predictions, or other quantities of interest.
156+
fn expand_vector<R: Rng + ?Sized>(
157+
&mut self,
158+
_rng: &mut R,
159+
array: &[f64],
160+
) -> Result<Self::ExpandedVector, CpuMathError> {
161+
// Store the raw parameter values and compute a simple derived quantity
162+
Ok(ExpandedDraw {
163+
prec: array.to_vec(),
164+
diff: array[1] - array[0], // Example: difference between first two parameters
165+
})
166+
}
167+
168+
fn vector_coord(&self) -> Option<Value> {
169+
Some(Value::Strings(vec!["x1".to_string(), "x2".to_string()]))
170+
}
171+
}
172+
173+
/// The complete MCMC model
174+
///
175+
/// This struct implements the Model trait, which is the main interface
176+
/// that samplers use. It provides access to the mathematical operations
177+
/// and handles initialization of the sampling chains.
178+
struct MvnModel {
179+
math: CpuMath<MvnLogp>,
180+
}
181+
182+
impl Model for MvnModel {
183+
type Math<'model>
184+
= CpuMath<MvnLogp>
185+
where
186+
Self: 'model;
187+
188+
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
189+
Ok(self.math.clone())
190+
}
191+
192+
/// Generate random initial positions for the chain
193+
///
194+
/// Good initialization is important for MCMC efficiency. The starting
195+
/// points should be in a reasonable region of the parameter space
196+
/// where the log probability is finite.
197+
fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
198+
// Initialize each parameter randomly in the range [-2, 2]
199+
// For this simple example, this should put us in a reasonable
200+
// region around the mode of the distribution
201+
for p in position.iter_mut() {
202+
*p = rng.random_range(-2.0..2.0);
203+
}
204+
Ok(())
205+
}
206+
}
207+
208+
fn main() -> Result<()> {
209+
println!("=== Multivariate Normal MCMC with Arrow Storage ===\n");
210+
211+
// Create a 2D multivariate normal distribution
212+
// This creates a distribution with mean [0, 0] and precision matrix
213+
// [[1.0, 0.5], [0.5, 1.0]], which corresponds to some correlation
214+
// between the two parameters
215+
let mean = vec![0.0, 0.0];
216+
let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
217+
let mvn = MultivariateNormal::new(mean, precision);
218+
219+
println!("Model: 2D Multivariate Normal");
220+
println!("Mean: {:?}", mvn.mean);
221+
println!("Precision matrix: {:?}\n", mvn.precision);
222+
223+
// Configure output location
224+
let output_path = "mcmc_output";
225+
println!("Output will be saved to: {}/\n", output_path);
226+
227+
// Sampling configuration
228+
let num_chains = 4; // Run 4 parallel chains for better convergence assessment
229+
let num_tune = 500; // Warmup samples to tune the sampler
230+
let num_draws = 500; // Post-warmup samples to keep
231+
232+
println!("Sampling configuration:");
233+
println!(" Chains: {}", num_chains);
234+
println!(" Warmup samples: {}", num_tune);
235+
println!(" Sampling draws: {}", num_draws);
236+
237+
// Configure MCMC settings
238+
// DiagGradNutsSettings provides sensible defaults for the NUTS sampler
239+
let mut settings = DiagGradNutsSettings::default();
240+
settings.num_chains = num_chains as _;
241+
settings.num_tune = num_tune;
242+
settings.num_draws = num_draws as _;
243+
settings.seed = 54; // For reproducible results
244+
245+
// Create the model instance
246+
let model = MvnModel {
247+
math: CpuMath::new(MvnLogp {
248+
model: mvn,
249+
buffer: vec![0.0; 2],
250+
}),
251+
};
252+
253+
// Start sampling
254+
println!("\nStarting MCMC sampling...\n");
255+
let start = Instant::now();
256+
257+
// Configure Arrow storage - it automatically determines capacity from settings
258+
let arrow_config = ArrowConfig::new();
259+
260+
// Create sampler with 4 worker threads
261+
// The sampler runs asynchronously, so we can monitor progress
262+
let mut sampler = Some(Sampler::new(model, settings, arrow_config, 4, None)?);
263+
264+
let mut num_progress_updates = 0;
265+
266+
// Main sampling loop with progress monitoring
267+
// This demonstrates how to monitor long-running sampling jobs
268+
while let Some(sampler_) = sampler.take() {
269+
match sampler_.wait_timeout(Duration::from_millis(50)) {
270+
// Sampling completed successfully
271+
SamplerWaitResult::Trace(traces) => {
272+
println!("✓ Sampling completed in {:?}", start.elapsed());
273+
274+
// Display information about the resulting Arrow data
275+
println!("✓ MCMC traces stored in Arrow format");
276+
println!("\nTrace summary:");
277+
println!(" Number of chains: {}", traces.len());
278+
279+
if let Some(first_trace) = traces.first() {
280+
println!(
281+
" Posterior samples: {} rows, {} columns",
282+
first_trace.posterior.num_rows(),
283+
first_trace.posterior.num_columns()
284+
);
285+
println!(
286+
" Sample stats: {} rows, {} columns",
287+
first_trace.sample_stats.num_rows(),
288+
first_trace.sample_stats.num_columns()
289+
);
290+
291+
// Show column names
292+
println!("\n Posterior columns:");
293+
for field in first_trace.posterior.schema().fields() {
294+
println!(
295+
" {} ({} {:?})",
296+
field.name(),
297+
field.data_type(),
298+
field.metadata(),
299+
);
300+
}
301+
302+
println!("\n Sample stats columns:");
303+
for field in first_trace.sample_stats.schema().fields() {
304+
println!(
305+
" {} ({} {:?})",
306+
field.name(),
307+
field.data_type(),
308+
field.metadata(),
309+
);
310+
}
311+
}
312+
break;
313+
}
314+
315+
// Timeout - sampler is still running, show progress
316+
SamplerWaitResult::Timeout(mut sampler_) => {
317+
num_progress_updates += 1;
318+
println!("Progress update {}:", num_progress_updates);
319+
320+
// Get current progress from all chains
321+
let progress = sampler_.progress()?;
322+
for (i, chain) in progress.iter().enumerate() {
323+
let phase = if chain.tuning { "warmup" } else { "sampling" };
324+
println!(
325+
" Chain {}: {} samples ({} divergences), step size: {:.6} [{}]",
326+
i, chain.finished_draws, chain.divergences, chain.step_size, phase
327+
);
328+
}
329+
println!(); // Add blank line for readability
330+
331+
sampler = Some(sampler_);
332+
}
333+
334+
// An error occurred during sampling
335+
SamplerWaitResult::Err(err, _) => {
336+
eprintln!("✗ Sampling failed: {}", err);
337+
return Err(err);
338+
}
339+
}
340+
}
341+
342+
Ok(())
343+
}

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,6 @@ pub use storage::{CsvConfig, CsvTraceStorage};
145145
pub use storage::{HashMapConfig, HashMapValue};
146146
#[cfg(feature = "ndarray")]
147147
pub use storage::{NdarrayConfig, NdarrayTrace, NdarrayValue};
148+
149+
#[cfg(feature = "arrow")]
150+
pub use storage::{ArrowConfig, ArrowTrace, ArrowTraceStorage};

0 commit comments

Comments
 (0)