|
| 1 | +//! Zarr async backend example for MCMC trace storage |
| 2 | +//! |
| 3 | +//! This example demonstrates how to use the nuts-rs library with async Zarr storage |
| 4 | +//! for running MCMC sampling on a multivariate normal distribution. It shows: |
| 5 | +//! |
| 6 | +//! - Setting up a custom probability model |
| 7 | +//! - Configuring async Zarr storage for results |
| 8 | +//! - Running multiple parallel chains with async I/O |
| 9 | +//! - Monitoring progress during sampling |
| 10 | +//! - Saving results in ArviZ-compatible format |
| 11 | +
|
| 12 | +use std::{ |
| 13 | + collections::HashMap, |
| 14 | + f64, |
| 15 | + sync::Arc, |
| 16 | + time::{Duration, Instant}, |
| 17 | +}; |
| 18 | + |
| 19 | +use anyhow::Result; |
| 20 | +use nuts_rs::{ |
| 21 | + CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, Sampler, |
| 22 | + SamplerWaitResult, Storable, ZarrAsyncConfig, |
| 23 | +}; |
| 24 | +use nuts_storable::{HasDims, Value}; |
| 25 | +use rand::Rng; |
| 26 | +use thiserror::Error; |
| 27 | +use zarrs::filesystem::FilesystemStore; |
| 28 | +use zarrs_object_store::AsyncObjectStore; |
| 29 | + |
/// Parameters of a multivariate normal distribution.
///
/// The distribution is described by its mean vector μ and its precision
/// matrix P (the inverse of the covariance matrix). Up to an additive
/// constant, the log density is `logp(x) = -0.5 * (x - μ)^T * P * (x - μ)`.
#[derive(Debug, Clone)]
struct MultivariateNormal {
    /// Mean vector μ; its length is the dimensionality of the distribution.
    mean: Vec<f64>,
    /// Precision matrix P, stored row by row.
    precision: Vec<Vec<f64>>,
}

impl MultivariateNormal {
    /// Bundle a mean vector and a precision matrix into a model description.
    ///
    /// No shape validation is performed; the caller is responsible for
    /// passing a square `precision` whose size matches `mean`.
    fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
        MultivariateNormal { mean, precision }
    }
}
| 45 | + |
| 46 | +/// Custom error type for log probability calculations |
| 47 | +/// |
| 48 | +/// MCMC samplers need to distinguish between recoverable errors (like numerical |
| 49 | +/// issues that can be handled by rejecting the proposal) and non-recoverable |
| 50 | +/// errors (like programming bugs that should stop sampling). |
| 51 | +#[allow(dead_code)] |
| 52 | +#[derive(Debug, Error)] |
| 53 | +enum MyLogpError { |
| 54 | + #[error("Recoverable error in logp calculation: {0}")] |
| 55 | + Recoverable(String), |
| 56 | + #[error("Non-recoverable error in logp calculation: {0}")] |
| 57 | + NonRecoverable(String), |
| 58 | +} |
| 59 | + |
| 60 | +impl LogpError for MyLogpError { |
| 61 | + fn is_recoverable(&self) -> bool { |
| 62 | + matches!(self, MyLogpError::Recoverable(_)) |
| 63 | + } |
| 64 | +} |
| 65 | + |
| 66 | +/// Implementation of the log probability function for multivariate normal |
| 67 | +/// |
| 68 | +/// This struct contains the model parameters and implements the mathematical |
| 69 | +/// operations needed for MCMC sampling: computing log probability and gradients. |
| 70 | +#[derive(Clone)] |
| 71 | +struct MvnLogp { |
| 72 | + model: MultivariateNormal, |
| 73 | + buffer: Vec<f64>, // Temporary buffer for computations |
| 74 | +} |
| 75 | + |
| 76 | +impl HasDims for MvnLogp { |
| 77 | + /// Define dimension names and sizes for storage |
| 78 | + /// |
| 79 | + /// This tells the storage system what array dimensions to expect. |
| 80 | + /// These dimensions will be used to structure the output data. |
| 81 | + fn dim_sizes(&self) -> HashMap<String, u64> { |
| 82 | + HashMap::from([ |
| 83 | + // Dimension for the actual parameter vector x |
| 84 | + ("x".to_string(), self.model.mean.len() as u64), |
| 85 | + ]) |
| 86 | + } |
| 87 | + |
| 88 | + fn coords(&self) -> HashMap<String, nuts_storable::Value> { |
| 89 | + HashMap::from([( |
| 90 | + "x".to_string(), |
| 91 | + Value::Strings(vec!["x1".to_string(), "x2".to_string()]), |
| 92 | + )]) |
| 93 | + } |
| 94 | +} |
| 95 | + |
| 96 | +/// Additional quantities computed from each sample |
| 97 | +/// |
| 98 | +/// The `Storable` derive macro automatically generates code to store this |
| 99 | +/// struct in the trace. The `dims` attribute specifies which dimension |
| 100 | +/// each field should use. |
| 101 | +#[derive(Storable)] |
| 102 | +struct ExpandedDraw { |
| 103 | + /// Store the parameter values with dimension "x" |
| 104 | + #[storable(dims("x"))] |
| 105 | + prec: Vec<f64>, |
| 106 | + /// A scalar derived quantity (difference between first two parameters) |
| 107 | + diff: f64, |
| 108 | +} |
| 109 | + |
| 110 | +impl CpuLogpFunc for MvnLogp { |
| 111 | + type LogpError = MyLogpError; |
| 112 | + type FlowParameters = (); // No parameter transformations needed |
| 113 | + type ExpandedVector = ExpandedDraw; |
| 114 | + |
| 115 | + /// Return the dimensionality of the parameter space |
| 116 | + fn dim(&self) -> usize { |
| 117 | + self.model.mean.len() |
| 118 | + } |
| 119 | + |
| 120 | + /// Compute log probability and gradient |
| 121 | + /// |
| 122 | + /// This is the core mathematical function that MCMC uses to explore |
| 123 | + /// the parameter space. It computes both the log probability density |
| 124 | + /// and its gradient for efficient sampling with Hamiltonian Monte Carlo. |
| 125 | + fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> { |
| 126 | + let n = x.len(); |
| 127 | + |
| 128 | + // Compute (x - mean) |
| 129 | + let diff = &mut self.buffer; |
| 130 | + for i in 0..n { |
| 131 | + diff[i] = x[i] - self.model.mean[i]; |
| 132 | + } |
| 133 | + |
| 134 | + let mut quad = 0.0; |
| 135 | + |
| 136 | + // Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff |
| 137 | + for i in 0..n { |
| 138 | + // Compute i-th component of P * diff |
| 139 | + let mut pdot = 0.0; |
| 140 | + for j in 0..n { |
| 141 | + let pij = self.model.precision[i][j]; |
| 142 | + pdot += pij * diff[j]; |
| 143 | + quad += diff[i] * pij * diff[j]; |
| 144 | + } |
| 145 | + // Gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is -(P * diff)_i |
| 146 | + grad[i] = -pdot; |
| 147 | + } |
| 148 | + |
| 149 | + Ok(-0.5 * quad) |
| 150 | + } |
| 151 | + |
| 152 | + /// Compute additional quantities from each sample |
| 153 | + /// |
| 154 | + /// This function is called for each accepted sample to compute derived |
| 155 | + /// quantities that should be stored in the trace. These might be |
| 156 | + /// transformed parameters, predictions, or other quantities of interest. |
| 157 | + fn expand_vector<R: Rng + ?Sized>( |
| 158 | + &mut self, |
| 159 | + _rng: &mut R, |
| 160 | + array: &[f64], |
| 161 | + ) -> Result<Self::ExpandedVector, CpuMathError> { |
| 162 | + // Store the raw parameter values and compute a simple derived quantity |
| 163 | + Ok(ExpandedDraw { |
| 164 | + prec: array.to_vec(), |
| 165 | + diff: array[1] - array[0], // Example: difference between first two parameters |
| 166 | + }) |
| 167 | + } |
| 168 | + |
| 169 | + fn vector_coord(&self) -> Option<Value> { |
| 170 | + Some(Value::Strings(vec!["x1".to_string(), "x2".to_string()])) |
| 171 | + } |
| 172 | +} |
| 173 | + |
| 174 | +/// The complete MCMC model |
| 175 | +/// |
| 176 | +/// This struct implements the Model trait, which is the main interface |
| 177 | +/// that samplers use. It provides access to the mathematical operations |
| 178 | +/// and handles initialization of the sampling chains. |
| 179 | +struct MvnModel { |
| 180 | + math: CpuMath<MvnLogp>, |
| 181 | +} |
| 182 | + |
| 183 | +impl Model for MvnModel { |
| 184 | + type Math<'model> |
| 185 | + = CpuMath<MvnLogp> |
| 186 | + where |
| 187 | + Self: 'model; |
| 188 | + |
| 189 | + fn math(&self) -> Result<Self::Math<'_>> { |
| 190 | + Ok(self.math.clone()) |
| 191 | + } |
| 192 | + |
| 193 | + /// Generate random initial positions for the chain |
| 194 | + /// |
| 195 | + /// Good initialization is important for MCMC efficiency. The starting |
| 196 | + /// points should be in a reasonable region of the parameter space |
| 197 | + /// where the log probability is finite. |
| 198 | + fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> { |
| 199 | + // Initialize each parameter randomly in the range [-2, 2] |
| 200 | + // For this simple example, this should put us in a reasonable |
| 201 | + // region around the mode of the distribution |
| 202 | + for p in position.iter_mut() { |
| 203 | + *p = rng.random_range(-2.0..2.0); |
| 204 | + } |
| 205 | + Ok(()) |
| 206 | + } |
| 207 | +} |
| 208 | + |
| 209 | +fn main() -> Result<()> { |
| 210 | + println!("=== Multivariate Normal MCMC with Async Zarr Storage ===\n"); |
| 211 | + |
| 212 | + // Create a 2D multivariate normal distribution |
| 213 | + // This creates a distribution with mean [0, 0] and precision matrix |
| 214 | + // [[1.0, 0.5], [0.5, 1.0]], which corresponds to some correlation |
| 215 | + // between the two parameters |
| 216 | + let mean = vec![0.0, 0.0]; |
| 217 | + let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]]; |
| 218 | + let mvn = MultivariateNormal::new(mean, precision); |
| 219 | + |
| 220 | + println!("Model: 2D Multivariate Normal"); |
| 221 | + println!("Mean: {:?}", mvn.mean); |
| 222 | + println!("Precision matrix: {:?}\n", mvn.precision); |
| 223 | + |
| 224 | + // Configure output location |
| 225 | + let output_path = "mcmc_output/async_trace.zarr"; |
| 226 | + println!("Output will be saved to: {}\n", output_path); |
| 227 | + |
| 228 | + // Sampling configuration |
| 229 | + let num_chains = 4; // Run 4 parallel chains for better convergence assessment |
| 230 | + let num_tune = 500; // Warmup samples to tune the sampler |
| 231 | + let num_draws = 500; // Post-warmup samples to keep |
| 232 | + |
| 233 | + println!("Sampling configuration:"); |
| 234 | + println!(" Chains: {}", num_chains); |
| 235 | + println!(" Warmup samples: {}", num_tune); |
| 236 | + println!(" Sampling draws: {}", num_draws); |
| 237 | + |
| 238 | + // Configure MCMC settings |
| 239 | + // DiagGradNutsSettings provides sensible defaults for the NUTS sampler |
| 240 | + let mut settings = DiagGradNutsSettings::default(); |
| 241 | + settings.num_chains = num_chains as _; |
| 242 | + settings.num_tune = num_tune; |
| 243 | + settings.num_draws = num_draws as _; |
| 244 | + settings.seed = 54; // For reproducible results |
| 245 | + |
| 246 | + let path = std::path::Path::new(output_path).canonicalize()?; |
| 247 | + let object_store = object_store::local::LocalFileSystem::new_with_prefix(path)?; |
| 248 | + let store = Arc::new(AsyncObjectStore::new(object_store)); |
| 249 | + |
| 250 | + // Create the model instance |
| 251 | + let model = MvnModel { |
| 252 | + math: CpuMath::new(MvnLogp { |
| 253 | + model: mvn, |
| 254 | + buffer: vec![0.0; 2], |
| 255 | + }), |
| 256 | + }; |
| 257 | + |
| 258 | + // Start sampling |
| 259 | + println!("\nStarting MCMC sampling with async Zarr backend...\n"); |
| 260 | + let start = Instant::now(); |
| 261 | + |
| 262 | + // Configure async Zarr storage with default settings |
| 263 | + // This uses async I/O operations to avoid blocking during writes |
| 264 | + let rt = tokio::runtime::Builder::new_multi_thread() |
| 265 | + .worker_threads(4) |
| 266 | + .enable_all() |
| 267 | + .build() |
| 268 | + .unwrap(); |
| 269 | + let handle = rt.handle().clone(); |
| 270 | + let zarr_async_config = ZarrAsyncConfig::new(handle, store.clone()); |
| 271 | + |
| 272 | + // Create sampler with 4 worker threads |
| 273 | + // The sampler runs asynchronously, so we can monitor progress |
| 274 | + let mut sampler = Some(Sampler::new(model, settings, zarr_async_config, 4, None)?); |
| 275 | + |
| 276 | + let mut num_progress_updates = 0; |
| 277 | + |
| 278 | + // Main sampling loop with progress monitoring |
| 279 | + // This demonstrates how to monitor long-running sampling jobs |
| 280 | + while let Some(sampler_) = sampler.take() { |
| 281 | + match sampler_.wait_timeout(Duration::from_millis(50)) { |
| 282 | + // Sampling completed successfully |
| 283 | + SamplerWaitResult::Trace(_) => { |
| 284 | + println!("✓ Async sampling completed in {:?}", start.elapsed()); |
| 285 | + println!("✓ Traces written to Zarr format at '{}'", output_path); |
| 286 | + |
| 287 | + // Provide instructions for analysis |
| 288 | + println!("\n=== Next Steps ==="); |
| 289 | + println!("To analyze results in Python with ArviZ:"); |
| 290 | + println!(" import arviz as az"); |
| 291 | + println!(" data = az.from_zarr('{}')", output_path); |
| 292 | + println!(" az.plot_trace(data)"); |
| 293 | + println!(" az.summary(data)"); |
| 294 | + println!("\nThe async Zarr format contains:"); |
| 295 | + println!(" - posterior/: Main sampling results"); |
| 296 | + println!(" - sample_stats/: Sampler diagnostics"); |
| 297 | + println!(" - warmup_*: Warmup phase results"); |
| 298 | + println!("\nNote: The async backend uses tokio tasks for I/O operations,"); |
| 299 | + println!(" which can improve performance by avoiding blocking writes."); |
| 300 | + break; |
| 301 | + } |
| 302 | + |
| 303 | + // Timeout - sampler is still running, show progress |
| 304 | + SamplerWaitResult::Timeout(mut sampler_) => { |
| 305 | + num_progress_updates += 1; |
| 306 | + println!("Progress update {} (async I/O):", num_progress_updates); |
| 307 | + |
| 308 | + // Get current progress from all chains |
| 309 | + let progress = sampler_.progress()?; |
| 310 | + for (i, chain) in progress.iter().enumerate() { |
| 311 | + let phase = if chain.tuning { "warmup" } else { "sampling" }; |
| 312 | + println!( |
| 313 | + " Chain {}: {} samples ({} divergences), step size: {:.6} [{}]", |
| 314 | + i, chain.finished_draws, chain.divergences, chain.step_size, phase |
| 315 | + ); |
| 316 | + } |
| 317 | + println!(" (Zarr writes are happening asynchronously in the background)"); |
| 318 | + println!(); // Add blank line for readability |
| 319 | + |
| 320 | + sampler = Some(sampler_); |
| 321 | + } |
| 322 | + |
| 323 | + // An error occurred during sampling |
| 324 | + SamplerWaitResult::Err(err, _) => { |
| 325 | + eprintln!("✗ Async sampling failed: {}", err); |
| 326 | + return Err(err); |
| 327 | + } |
| 328 | + } |
| 329 | + } |
| 330 | + |
| 331 | + Ok(()) |
| 332 | +} |
0 commit comments