Skip to content

Commit 36071c9

Browse files
committed
feat: implement async zarr storage
1 parent 824061f commit 36071c9

File tree

7 files changed

+1303
-260
lines changed

7 files changed

+1303
-260
lines changed

Cargo.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@ zarrs = { version = "0.21.0", features = [
3232
"filesystem",
3333
"gzip",
3434
"sharding",
35+
"async",
3536
], optional = true }
3637
ndarray = { version = "0.16.1", optional = true }
3738
nuts-derive = { path = "./nuts-derive" }
3839
nuts-storable = { path = "./nuts-storable" }
3940
serde = { version = "1.0.219", features = ["derive"] }
4041
serde_json = "1.0"
42+
tokio = { version = "1.0", features = ["rt"], optional = true }
4143

4244
[dev-dependencies]
4345
proptest = "1.6.0"
@@ -49,9 +51,12 @@ equator = "0.4.2"
4951
serde_json = "1.0"
5052
ndarray = "0.16.1"
5153
tempfile = "3.0"
54+
zarrs_object_store = "0.4.3"
55+
object_store = "0.12.0"
56+
tokio = { version = "1.0", features = ["rt", "rt-multi-thread"]}
5257

5358
[features]
54-
zarr = ["dep:zarrs"]
59+
zarr = ["dep:zarrs", "dep:tokio"]
5560
ndarray = ["dep:ndarray"]
5661

5762
[[bench]]

examples/zarr_async_trace.rs

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
//! Zarr async backend example for MCMC trace storage
2+
//!
3+
//! This example demonstrates how to use the nuts-rs library with async Zarr storage
4+
//! for running MCMC sampling on a multivariate normal distribution. It shows:
5+
//!
6+
//! - Setting up a custom probability model
7+
//! - Configuring async Zarr storage for results
8+
//! - Running multiple parallel chains with async I/O
9+
//! - Monitoring progress during sampling
10+
//! - Saving results in ArviZ-compatible format
11+
12+
use std::{
13+
collections::HashMap,
14+
f64,
15+
sync::Arc,
16+
time::{Duration, Instant},
17+
};
18+
19+
use anyhow::Result;
20+
use nuts_rs::{
21+
CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, Sampler,
22+
SamplerWaitResult, Storable, ZarrAsyncConfig,
23+
};
24+
use nuts_storable::{HasDims, Value};
25+
use rand::Rng;
26+
use thiserror::Error;
27+
use zarrs::filesystem::FilesystemStore;
28+
use zarrs_object_store::AsyncObjectStore;
29+
30+
/// A multivariate normal distribution model
31+
///
32+
/// This represents a probability distribution with mean μ and precision matrix P,
33+
/// where the log probability is: logp(x) = -0.5 * (x - μ)^T * P * (x - μ)
34+
#[derive(Clone, Debug)]
35+
struct MultivariateNormal {
36+
mean: Vec<f64>,
37+
precision: Vec<Vec<f64>>, // Inverse of covariance matrix
38+
}
39+
40+
impl MultivariateNormal {
41+
fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
42+
Self { mean, precision }
43+
}
44+
}
45+
46+
/// Custom error type for log probability calculations
47+
///
48+
/// MCMC samplers need to distinguish between recoverable errors (like numerical
49+
/// issues that can be handled by rejecting the proposal) and non-recoverable
50+
/// errors (like programming bugs that should stop sampling).
51+
#[allow(dead_code)]
52+
#[derive(Debug, Error)]
53+
enum MyLogpError {
54+
#[error("Recoverable error in logp calculation: {0}")]
55+
Recoverable(String),
56+
#[error("Non-recoverable error in logp calculation: {0}")]
57+
NonRecoverable(String),
58+
}
59+
60+
impl LogpError for MyLogpError {
61+
fn is_recoverable(&self) -> bool {
62+
matches!(self, MyLogpError::Recoverable(_))
63+
}
64+
}
65+
66+
/// Implementation of the log probability function for multivariate normal
67+
///
68+
/// This struct contains the model parameters and implements the mathematical
69+
/// operations needed for MCMC sampling: computing log probability and gradients.
70+
#[derive(Clone)]
71+
struct MvnLogp {
72+
model: MultivariateNormal,
73+
buffer: Vec<f64>, // Temporary buffer for computations
74+
}
75+
76+
impl HasDims for MvnLogp {
77+
/// Define dimension names and sizes for storage
78+
///
79+
/// This tells the storage system what array dimensions to expect.
80+
/// These dimensions will be used to structure the output data.
81+
fn dim_sizes(&self) -> HashMap<String, u64> {
82+
HashMap::from([
83+
// Dimension for the actual parameter vector x
84+
("x".to_string(), self.model.mean.len() as u64),
85+
])
86+
}
87+
88+
fn coords(&self) -> HashMap<String, nuts_storable::Value> {
89+
HashMap::from([(
90+
"x".to_string(),
91+
Value::Strings(vec!["x1".to_string(), "x2".to_string()]),
92+
)])
93+
}
94+
}
95+
96+
/// Additional quantities computed from each sample
97+
///
98+
/// The `Storable` derive macro automatically generates code to store this
99+
/// struct in the trace. The `dims` attribute specifies which dimension
100+
/// each field should use.
101+
#[derive(Storable)]
102+
struct ExpandedDraw {
103+
/// Store the parameter values with dimension "x"
104+
#[storable(dims("x"))]
105+
prec: Vec<f64>,
106+
/// A scalar derived quantity (difference between first two parameters)
107+
diff: f64,
108+
}
109+
110+
impl CpuLogpFunc for MvnLogp {
111+
type LogpError = MyLogpError;
112+
type FlowParameters = (); // No parameter transformations needed
113+
type ExpandedVector = ExpandedDraw;
114+
115+
/// Return the dimensionality of the parameter space
116+
fn dim(&self) -> usize {
117+
self.model.mean.len()
118+
}
119+
120+
/// Compute log probability and gradient
121+
///
122+
/// This is the core mathematical function that MCMC uses to explore
123+
/// the parameter space. It computes both the log probability density
124+
/// and its gradient for efficient sampling with Hamiltonian Monte Carlo.
125+
fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
126+
let n = x.len();
127+
128+
// Compute (x - mean)
129+
let diff = &mut self.buffer;
130+
for i in 0..n {
131+
diff[i] = x[i] - self.model.mean[i];
132+
}
133+
134+
let mut quad = 0.0;
135+
136+
// Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff
137+
for i in 0..n {
138+
// Compute i-th component of P * diff
139+
let mut pdot = 0.0;
140+
for j in 0..n {
141+
let pij = self.model.precision[i][j];
142+
pdot += pij * diff[j];
143+
quad += diff[i] * pij * diff[j];
144+
}
145+
// Gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is -(P * diff)_i
146+
grad[i] = -pdot;
147+
}
148+
149+
Ok(-0.5 * quad)
150+
}
151+
152+
/// Compute additional quantities from each sample
153+
///
154+
/// This function is called for each accepted sample to compute derived
155+
/// quantities that should be stored in the trace. These might be
156+
/// transformed parameters, predictions, or other quantities of interest.
157+
fn expand_vector<R: Rng + ?Sized>(
158+
&mut self,
159+
_rng: &mut R,
160+
array: &[f64],
161+
) -> Result<Self::ExpandedVector, CpuMathError> {
162+
// Store the raw parameter values and compute a simple derived quantity
163+
Ok(ExpandedDraw {
164+
prec: array.to_vec(),
165+
diff: array[1] - array[0], // Example: difference between first two parameters
166+
})
167+
}
168+
169+
fn vector_coord(&self) -> Option<Value> {
170+
Some(Value::Strings(vec!["x1".to_string(), "x2".to_string()]))
171+
}
172+
}
173+
174+
/// The complete MCMC model
175+
///
176+
/// This struct implements the Model trait, which is the main interface
177+
/// that samplers use. It provides access to the mathematical operations
178+
/// and handles initialization of the sampling chains.
179+
struct MvnModel {
180+
math: CpuMath<MvnLogp>,
181+
}
182+
183+
impl Model for MvnModel {
184+
type Math<'model>
185+
= CpuMath<MvnLogp>
186+
where
187+
Self: 'model;
188+
189+
fn math(&self) -> Result<Self::Math<'_>> {
190+
Ok(self.math.clone())
191+
}
192+
193+
/// Generate random initial positions for the chain
194+
///
195+
/// Good initialization is important for MCMC efficiency. The starting
196+
/// points should be in a reasonable region of the parameter space
197+
/// where the log probability is finite.
198+
fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
199+
// Initialize each parameter randomly in the range [-2, 2]
200+
// For this simple example, this should put us in a reasonable
201+
// region around the mode of the distribution
202+
for p in position.iter_mut() {
203+
*p = rng.random_range(-2.0..2.0);
204+
}
205+
Ok(())
206+
}
207+
}
208+
209+
fn main() -> Result<()> {
210+
println!("=== Multivariate Normal MCMC with Async Zarr Storage ===\n");
211+
212+
// Create a 2D multivariate normal distribution
213+
// This creates a distribution with mean [0, 0] and precision matrix
214+
// [[1.0, 0.5], [0.5, 1.0]], which corresponds to some correlation
215+
// between the two parameters
216+
let mean = vec![0.0, 0.0];
217+
let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
218+
let mvn = MultivariateNormal::new(mean, precision);
219+
220+
println!("Model: 2D Multivariate Normal");
221+
println!("Mean: {:?}", mvn.mean);
222+
println!("Precision matrix: {:?}\n", mvn.precision);
223+
224+
// Configure output location
225+
let output_path = "mcmc_output/async_trace.zarr";
226+
println!("Output will be saved to: {}\n", output_path);
227+
228+
// Sampling configuration
229+
let num_chains = 4; // Run 4 parallel chains for better convergence assessment
230+
let num_tune = 500; // Warmup samples to tune the sampler
231+
let num_draws = 500; // Post-warmup samples to keep
232+
233+
println!("Sampling configuration:");
234+
println!(" Chains: {}", num_chains);
235+
println!(" Warmup samples: {}", num_tune);
236+
println!(" Sampling draws: {}", num_draws);
237+
238+
// Configure MCMC settings
239+
// DiagGradNutsSettings provides sensible defaults for the NUTS sampler
240+
let mut settings = DiagGradNutsSettings::default();
241+
settings.num_chains = num_chains as _;
242+
settings.num_tune = num_tune;
243+
settings.num_draws = num_draws as _;
244+
settings.seed = 54; // For reproducible results
245+
246+
let path = std::path::Path::new(output_path).canonicalize()?;
247+
let object_store = object_store::local::LocalFileSystem::new_with_prefix(path)?;
248+
let store = Arc::new(AsyncObjectStore::new(object_store));
249+
250+
// Create the model instance
251+
let model = MvnModel {
252+
math: CpuMath::new(MvnLogp {
253+
model: mvn,
254+
buffer: vec![0.0; 2],
255+
}),
256+
};
257+
258+
// Start sampling
259+
println!("\nStarting MCMC sampling with async Zarr backend...\n");
260+
let start = Instant::now();
261+
262+
// Configure async Zarr storage with default settings
263+
// This uses async I/O operations to avoid blocking during writes
264+
let rt = tokio::runtime::Builder::new_multi_thread()
265+
.worker_threads(4)
266+
.enable_all()
267+
.build()
268+
.unwrap();
269+
let handle = rt.handle().clone();
270+
let zarr_async_config = ZarrAsyncConfig::new(handle, store.clone());
271+
272+
// Create sampler with 4 worker threads
273+
// The sampler runs asynchronously, so we can monitor progress
274+
let mut sampler = Some(Sampler::new(model, settings, zarr_async_config, 4, None)?);
275+
276+
let mut num_progress_updates = 0;
277+
278+
// Main sampling loop with progress monitoring
279+
// This demonstrates how to monitor long-running sampling jobs
280+
while let Some(sampler_) = sampler.take() {
281+
match sampler_.wait_timeout(Duration::from_millis(50)) {
282+
// Sampling completed successfully
283+
SamplerWaitResult::Trace(_) => {
284+
println!("✓ Async sampling completed in {:?}", start.elapsed());
285+
println!("✓ Traces written to Zarr format at '{}'", output_path);
286+
287+
// Provide instructions for analysis
288+
println!("\n=== Next Steps ===");
289+
println!("To analyze results in Python with ArviZ:");
290+
println!(" import arviz as az");
291+
println!(" data = az.from_zarr('{}')", output_path);
292+
println!(" az.plot_trace(data)");
293+
println!(" az.summary(data)");
294+
println!("\nThe async Zarr format contains:");
295+
println!(" - posterior/: Main sampling results");
296+
println!(" - sample_stats/: Sampler diagnostics");
297+
println!(" - warmup_*: Warmup phase results");
298+
println!("\nNote: The async backend uses tokio tasks for I/O operations,");
299+
println!(" which can improve performance by avoiding blocking writes.");
300+
break;
301+
}
302+
303+
// Timeout - sampler is still running, show progress
304+
SamplerWaitResult::Timeout(mut sampler_) => {
305+
num_progress_updates += 1;
306+
println!("Progress update {} (async I/O):", num_progress_updates);
307+
308+
// Get current progress from all chains
309+
let progress = sampler_.progress()?;
310+
for (i, chain) in progress.iter().enumerate() {
311+
let phase = if chain.tuning { "warmup" } else { "sampling" };
312+
println!(
313+
" Chain {}: {} samples ({} divergences), step size: {:.6} [{}]",
314+
i, chain.finished_draws, chain.divergences, chain.step_size, phase
315+
);
316+
}
317+
println!(" (Zarr writes are happening asynchronously in the background)");
318+
println!(); // Add blank line for readability
319+
320+
sampler = Some(sampler_);
321+
}
322+
323+
// An error occurred during sampling
324+
SamplerWaitResult::Err(err, _) => {
325+
eprintln!("✗ Async sampling failed: {}", err);
326+
return Err(err);
327+
}
328+
}
329+
}
330+
331+
Ok(())
332+
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ pub use stepsize_adapt::{StepSizeAdaptMethod, StepSizeAdaptOptions, StepSizeSett
148148
pub use transform_adapt_strategy::TransformedSettings;
149149

150150
#[cfg(feature = "zarr")]
151-
pub use zarr_storage::{ZarrConfig, ZarrTraceStorage};
151+
pub use zarr_storage::{ZarrAsyncConfig, ZarrAsyncTraceStorage, ZarrConfig, ZarrTraceStorage};
152152

153153
pub use csv_storage::{CsvConfig, CsvTraceStorage};
154154
pub use hashmap_storage::{HashMapConfig, HashMapValue};

0 commit comments

Comments
 (0)