diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml
index 151301f..7832d5e 100644
--- a/.github/workflows/coverage.yml
+++ b/.github/workflows/coverage.yml
@@ -14,7 +14,7 @@ jobs:
       - name: Install cargo-llvm-cov
         uses: taiki-e/install-action@cargo-llvm-cov
       - name: Generate code coverage
-        run: cargo llvm-cov --workspace --lcov --output-path lcov.info
+        run: cargo llvm-cov --workspace --lcov --output-path lcov.info --all-features
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v4
         env:
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 06da3ba..9b65651 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -2,36 +2,10 @@ on: [push, pull_request]
 
 name: Run tests
 
-jobs:
-  check:
-    name: Check
-    strategy:
-      matrix:
-        rust:
-          - stable
-          - nightly
-    runs-on: ubuntu-latest
-    steps:
-      - name: Checkout sources
-        uses: actions/checkout@v2
-
-      - name: Install stable toolchain
-        uses: actions-rs/toolchain@v1
-        with:
-          profile: minimal
-          toolchain: ${{ matrix.rust }}
-          override: true
-
-      - name: Run cargo check nightly features
-        if: ${{ matrix.rust == 'nightly' }}
-        uses: actions-rs/cargo@v1
-        with:
-          command: check
-
-      - name: Run cargo check
-        uses: actions-rs/cargo@v1
-        with:
-          command: check
+env:
+  CARGO_TERM_COLOR: always
 
+jobs:
   test:
     name: Test Suite
     runs-on: ubuntu-latest
@@ -42,21 +16,7 @@ jobs:
         - nightly
     steps:
       - name: Checkout sources
-        uses: actions/checkout@v2
-
-      - name: Install stable toolchain
-        uses: actions-rs/toolchain@v1
-        with:
-          profile: minimal
-          toolchain: nightly
-          override: true
-
-      - name: Run cargo test with nightly features
-        if: ${{ matrix.rust == 'nightly' }}
-        uses: actions-rs/cargo@v1
-        with:
-          command: test
-
-      - name: Run cargo test
-        uses: actions-rs/cargo@v1
-        with:
-          command: test
+        uses: actions/checkout@v4
+      - run: rustup update ${{ matrix.rust }} && rustup default ${{ matrix.rust }}
+      - run: cargo build --verbose --all-features
+      - run: cargo test --verbose --all-features
diff --git a/Cargo.toml b/Cargo.toml
index 47c5845..85769ef 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,7 +5,7 @@ authors = [
     "Adrian Seyboldt <adrian.seyboldt@gmail.com>",
     "PyMC Developers <pymc.devs@gmail.com>",
 ]
-edition = "2021"
+edition = "2024"
 license = "MIT"
 repository = "https://github.com/pymc-devs/nuts-rs"
 keywords = ["statistics", "bayes"]
@@ -22,22 +22,44 @@ rand = { version = "0.9.0", features = ["small_rng"] }
 rand_distr = "0.5.0"
 itertools = "0.14.0"
 thiserror = "2.0.3"
-arrow = { version = "55.1.0", default-features = false, features = ["ffi"] }
 rand_chacha = "0.9.0"
 anyhow = "1.0.72"
 faer = { version = "0.22.6", default-features = false, features = ["linalg"] }
 pulp = "0.21.4"
 rayon = "1.10.0"
+zarrs = { version = "0.21.0", features = [
+    "filesystem",
+    "gzip",
+    "sharding",
+    "async",
+], optional = true }
+ndarray = { version = "0.16.1", optional = true }
+nuts-derive = { path = "./nuts-derive" }
+nuts-storable = { path = "./nuts-storable" }
+serde = { version = "1.0.219", features = ["derive"] }
+serde_json = "1.0"
+tokio = { version = "1.0", features = ["rt"], optional = true }
 
 [dev-dependencies]
 proptest = "1.6.0"
 pretty_assertions = "1.4.0"
-criterion = "0.6.0"
+criterion = "0.7.0"
 nix = { version = "0.30.0", features = ["sched"] }
 approx = "0.5.1"
-ndarray = "0.16.1"
 equator = "0.4.2"
+serde_json = "1.0"
+ndarray = "0.16.1"
+tempfile = "3.0"
+zarrs_object_store = "0.4.3"
+object_store = "0.12.0"
+tokio = { version = "1.0", features = ["rt", "rt-multi-thread"] }
+
+[features]
+zarr = ["dep:zarrs", "dep:tokio"]
+ndarray = ["dep:ndarray"]
 
 [[bench]]
 name = "sample"
 harness = false
+
+[workspace]
diff --git a/benches/sample.rs b/benches/sample.rs
index 9148493..54ee314 100644
--- a/benches/sample.rs
+++ b/benches/sample.rs
@@ -1,9 +1,10 @@
 use std::hint::black_box;
 
-use criterion::{criterion_group, criterion_main, Criterion};
-use nix::sched::{sched_setaffinity, CpuSet};
+use criterion::{Criterion, criterion_group, criterion_main};
+use nix::sched::{CpuSet, sched_setaffinity};
 use nix::unistd::Pid;
 use nuts_rs::{Chain, CpuLogpFunc, CpuMath, LogpError, Math, Settings};
+use nuts_storable::HasDims;
 use rand::SeedableRng;
 use rayon::ThreadPoolBuilder;
 use thiserror::Error;
@@ -22,11 +23,20 @@ impl LogpError for PosteriorLogpError {
     }
 }
 
+impl HasDims for PosteriorDensity {
+    fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
+        vec![("unconstrained_parameter".to_string(), self.dim() as u64)]
+            .into_iter()
+            .collect()
+    }
+}
+
 impl CpuLogpFunc for PosteriorDensity {
     type LogpError = PosteriorLogpError;
+    type ExpandedVector = Vec<f64>;
 
     // Only used for transforming adaptation.
-    type TransformParams = ();
+    type FlowParameters = ();
 
     // We define a 10 dimensional normal distribution
     fn dim(&self) -> usize {
@@ -48,6 +58,17 @@ impl CpuLogpFunc for PosteriorDensity {
             .sum();
         return Ok(logp);
     }
+
+    fn expand_vector<R>(
+        &mut self,
+        _rng: &mut R,
+        array: &[f64],
+    ) -> Result<Self::ExpandedVector, Self::LogpError>
+    where
+        R: rand::Rng + ?Sized,
+    {
+        Ok(array.to_vec())
+    }
 }
 
 fn make_sampler(dim: usize) -> impl Chain<CpuMath<PosteriorDensity>> {
diff --git a/examples/adam_adaptation.rs b/examples/adam_adaptation.rs
new file mode 100644
index 0000000..227249c
--- /dev/null
+++ b/examples/adam_adaptation.rs
@@ -0,0 +1,178 @@
+//! Example demonstrating the Adam optimizer for step size adaptation.
+//!
+//! This example shows how to use the Adam optimizer instead of dual averaging
+//! for adapting the step size in NUTS.
+
+use nuts_rs::{
+    AdamOptions, Chain, CpuLogpFunc, CpuMath, DiagGradNutsSettings, LogpError, Settings,
+    StepSizeAdaptMethod,
+};
+use nuts_storable::HasDims;
+use thiserror::Error;
+
+// Define a function that computes the unnormalized posterior density
+// and its gradient.
+#[derive(Debug)]
+struct PosteriorDensity {}
+
+// The density might fail in a recoverable or non-recoverable manner...
+#[derive(Debug, Error)]
+enum PosteriorLogpError {}
+impl LogpError for PosteriorLogpError {
+    fn is_recoverable(&self) -> bool {
+        false
+    }
+}
+
+impl HasDims for PosteriorDensity {
+    fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
+        vec![("unconstrained_parameter".to_string(), self.dim() as u64)]
+            .into_iter()
+            .collect()
+    }
+}
+
+impl CpuLogpFunc for PosteriorDensity {
+    type LogpError = PosteriorLogpError;
+    type ExpandedVector = Vec<f64>;
+
+    // Only used for transforming adaptation.
+    type FlowParameters = ();
+
+    // We define a 10 dimensional normal distribution
+    fn dim(&self) -> usize {
+        10
+    }
+
+    // The normal likelihood with mean 3 and its gradient.
+    fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
+        let mu = 3f64;
+        let logp = position
+            .iter()
+            .copied()
+            .zip(grad.iter_mut())
+            .map(|(x, grad)| {
+                let diff = x - mu;
+                *grad = -diff;
+                -diff * diff / 2f64
+            })
+            .sum();
+        return Ok(logp);
+    }
+
+    fn expand_vector<R>(
+        &mut self,
+        _rng: &mut R,
+        array: &[f64],
+    ) -> Result<Self::ExpandedVector, Self::LogpError>
+    where
+        R: rand::Rng + ?Sized,
+    {
+        Ok(array.to_vec())
+    }
+}
+
+fn main() {
+    println!("Running NUTS with Adam step size adaptation...");
+
+    // Create sampler settings with Adam optimizer
+    let mut settings = DiagGradNutsSettings::default();
+
+    // Configure for Adam adaptation
+    settings
+        .adapt_options
+        .step_size_settings
+        .adapt_options
+        .method = StepSizeAdaptMethod::Adam;
+
+    // Set Adam options
+    let adam_options = AdamOptions {
+        beta1: 0.9,
+        beta2: 0.999,
+        epsilon: 1e-8,
+        learning_rate: 0.05,
+    };
+
+    settings.adapt_options.step_size_settings.adapt_options.adam = adam_options;
+
+    // Standard MCMC settings
+    settings.num_tune = 1000;
+    settings.num_draws = 1000;
+    settings.maxdepth = 10;
+
+    // Create the posterior density function
+    let logp_func = PosteriorDensity {};
+    let math = CpuMath::new(logp_func);
+
+    // Initialize the sampler
+    let chain = 0;
+    let mut rng = rand::rng();
+    let mut sampler = settings.new_chain(chain, math, &mut rng);
+
+    // Set initial position
+    let initial_position = vec![0f64; 10];
+    sampler
+        .set_position(&initial_position)
+        .expect("Unrecoverable error during init");
+
+    // Collect samples
+    let mut trace = vec![];
+    let mut stats = vec![];
+
+    // Sampling with progress reporting
+    println!("Warmup phase:");
+    for i in 0..settings.num_tune {
+        if i % 100 == 0 {
+            println!("Warmup: {}/{}", i, settings.num_tune);
+        }
+
+        let (draw, info) = sampler.draw().expect("Unrecoverable error during sampling");
+        if i % 100 == 0 {
+            println!("Step size: {:?}", info.step_size);
+        }
+        trace.push(draw);
+        stats.push(info);
+    }
+    println!("Warmup: {}/{}", settings.num_tune, settings.num_tune);
+
+    println!("\nSampling phase:");
+    for i in 0..settings.num_draws {
+        if i % 100 == 0 {
+            print!("\rSampling: {}/{}", i, settings.num_draws);
+        }
+
+        let (draw, info) = sampler.draw().expect("Unrecoverable error during sampling");
+        trace.push(draw);
+        stats.push(info);
+    }
+    println!("\rSampling: {}/{}", settings.num_draws, settings.num_draws);
+
+    // Calculate mean of samples (post-warmup)
+    let warmup_samples = settings.num_tune as usize;
+    let mut means = vec![0.0; 10];
+
+    for i in warmup_samples..trace.len() {
+        for (j, mean) in means.iter_mut().enumerate() {
+            *mean += trace[i][j];
+        }
+    }
+
+    for mean in &mut means {
+        *mean /= settings.num_draws as f64;
+    }
+
+    // Print results
+    println!("\nResults after {} samples:", settings.num_draws);
+    println!("Target mean: 3.0 for all dimensions");
+    println!("Estimated means:");
+    for (i, mean) in means.iter().enumerate() {
+        println!("Dimension {}: {:.4}", i, mean);
+    }
+
+    // Print adaptation statistics
+    let last_stats = &stats[stats.len() - 1];
+    println!("\nFinal adaptation statistics:");
+    println!("Step size: {:.6}", last_stats.step_size);
+    // Note: the full acceptance stats are in the Progress struct, but we don't
+    // have direct access to mean_tree_accept
+    println!("Number of steps: {}", last_stats.num_steps);
+
+    println!("\nSampling completed successfully!");
+}
diff --git a/examples/csv_trace.rs b/examples/csv_trace.rs
new file mode 100644
index 0000000..e0f5090
--- /dev/null
+++ b/examples/csv_trace.rs
@@ -0,0 +1,334 @@
+//! CSV backend example for MCMC trace storage
+//!
+//! This example demonstrates how to use the nuts-rs library with CSV storage
+//! for running MCMC sampling on a multivariate normal distribution. It shows:
+//!
+//! - Setting up a custom probability model
+//! - Configuring CSV storage for results
+//! - Running multiple parallel chains
+//! - Monitoring progress during sampling
+//! - Saving results in CmdStan-compatible CSV format
+
+use std::{
+    collections::HashMap,
+    f64,
+    time::{Duration, Instant},
+};
+
+use anyhow::Result;
+use nuts_rs::{
+    CpuLogpFunc, CpuMath, CpuMathError, CsvConfig, DiagGradNutsSettings, LogpError, Model, Sampler,
+    SamplerWaitResult, Storable,
+};
+use nuts_storable::{HasDims, Value};
+use rand::Rng;
+use thiserror::Error;
+
+/// A multivariate normal distribution model
+///
+/// This represents a probability distribution with mean μ and precision matrix P,
+/// where the log probability is: logp(x) = -0.5 * (x - μ)^T * P * (x - μ)
+#[derive(Clone, Debug)]
+struct MultivariateNormal {
+    mean: Vec<f64>,
+    precision: Vec<Vec<f64>>, // Inverse of covariance matrix
+}
+
+impl MultivariateNormal {
+    fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
+        Self { mean, precision }
+    }
+}
+
+/// Custom error type for log probability calculations
+///
+/// MCMC samplers need to distinguish between recoverable errors (like numerical
+/// issues that can be handled by rejecting the proposal) and non-recoverable
+/// errors (like programming bugs that should stop sampling).
+#[allow(dead_code)]
+#[derive(Debug, Error)]
+enum MyLogpError {
+    #[error("Recoverable error in logp calculation: {0}")]
+    Recoverable(String),
+    #[error("Non-recoverable error in logp calculation: {0}")]
+    NonRecoverable(String),
+}
+
+impl LogpError for MyLogpError {
+    fn is_recoverable(&self) -> bool {
+        matches!(self, MyLogpError::Recoverable(_))
+    }
+}
+
+/// Implementation of the log probability function for multivariate normal
+///
+/// This struct contains the model parameters and implements the mathematical
+/// operations needed for MCMC sampling: computing log probability and gradients.
+#[derive(Clone)]
+struct MvnLogp {
+    model: MultivariateNormal,
+}
+
+impl HasDims for MvnLogp {
+    /// Define dimension names and sizes for storage
+    ///
+    /// This tells the storage system what array dimensions to expect.
+    /// These dimensions will be used to structure the output data.
+    fn dim_sizes(&self) -> HashMap<String, u64> {
+        HashMap::from([
+            // Dimension for the parameter vector (for CSV storage)
+            ("param".to_string(), self.model.mean.len() as u64),
+        ])
+    }
+
+    fn coords(&self) -> HashMap<String, Value> {
+        HashMap::from([(
+            "param".to_string(),
+            Value::Strings(vec!["mu1".to_string(), "mu2".to_string()]),
+        )])
+    }
+}
+
+/// Additional quantities computed from each sample
+///
+/// The `Storable` derive macro automatically generates code to store this
+/// struct in the trace. The `dims` attribute specifies which dimension
+/// each field should use.
+#[derive(Storable)]
+struct ExpandedDraw {
+    /// Store the parameter values as a vector with dimension "param"
+    #[storable(dims("param"))]
+    parameters: Vec<f64>,
+}
+
+impl CpuLogpFunc for MvnLogp {
+    type LogpError = MyLogpError;
+    type FlowParameters = (); // No parameter transformations needed
+    type ExpandedVector = ExpandedDraw;
+
+    /// Return the dimensionality of the parameter space
+    fn dim(&self) -> usize {
+        self.model.mean.len()
+    }
+
+    /// Compute log probability and gradient
+    ///
+    /// This is the core mathematical function that MCMC uses to explore
+    /// the parameter space. It computes both the log probability density
+    /// and its gradient for efficient sampling with Hamiltonian Monte Carlo.
+    fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
+        let n = x.len();
+
+        // Compute (x - mean)
+        let mut diff = vec![0.0; n];
+        for i in 0..n {
+            diff[i] = x[i] - self.model.mean[i];
+        }
+
+        let mut quad = 0.0;
+
+        // Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff
+        for i in 0..n {
+            // Compute i-th component of P * diff
+            let mut pdot = 0.0;
+            for j in 0..n {
+                let pij = self.model.precision[i][j];
+                pdot += pij * diff[j];
+                quad += diff[i] * pij * diff[j];
+            }
+            // Gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is -(P * diff)_i
+            grad[i] = -pdot;
+        }
+
+        Ok(-0.5 * quad)
+    }
+
+    /// Compute additional quantities from each sample
+    ///
+    /// This function is called for each accepted sample to compute derived
+    /// quantities that should be stored in the trace. These might be
+    /// transformed parameters, predictions, or other quantities of interest.
+    fn expand_vector<R: Rng + ?Sized>(
+        &mut self,
+        _rng: &mut R,
+        array: &[f64],
+    ) -> Result<Self::ExpandedVector, Self::LogpError> {
+        // Return the parameter vector for CSV storage
+        Ok(ExpandedDraw {
+            parameters: array.to_vec(),
+        })
+    }
+
+    fn vector_coord(&self) -> Option<Value> {
+        Some(Value::Strings(vec!["mu1".to_string(), "mu2".to_string()]))
+    }
+}
+
+/// The complete MCMC model
+///
+/// This struct implements the Model trait, which is the main interface
+/// that samplers use. It provides access to the mathematical operations
+/// and handles initialization of the sampling chains.
+struct MvnModel {
+    math: CpuMath<MvnLogp>,
+}
+
+impl Model for MvnModel {
+    type Math<'model>
+        = CpuMath<MvnLogp>
+    where
+        Self: 'model;
+
+    fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
+        Ok(self.math.clone())
+    }
+
+    /// Generate random initial positions for the chain
+    ///
+    /// Good initialization is important for MCMC efficiency. The starting
+    /// points should be in a reasonable region of the parameter space
+    /// where the log probability is finite.
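+    ///
+    /// Here every coordinate is drawn uniformly from `[-2, 2]`, which for this
+    /// example model keeps the starting points close to the mode at the origin.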
+    fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
+        // Initialize each parameter randomly in the range [-2, 2]
+        // For this simple example, this should put us in a reasonable
+        // region around the mode of the distribution
+        for p in position.iter_mut() {
+            *p = rng.random_range(-2.0..2.0);
+        }
+        Ok(())
+    }
+}
+
+fn main() -> Result<()> {
+    println!("=== Multivariate Normal MCMC with CSV Storage ===\n");
+
+    // Create a 2D multivariate normal distribution
+    // This creates a distribution with mean [0, 0] and precision matrix
+    // [[1.0, 0.5], [0.5, 1.0]], which corresponds to some correlation
+    // between the two parameters
+    let mean = vec![0.0, 0.0];
+    let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
+    let mvn = MultivariateNormal::new(mean, precision);
+
+    println!("Model: 2D Multivariate Normal");
+    println!("Mean: {:?}", mvn.mean);
+    println!("Precision matrix: {:?}\n", mvn.precision);
+
+    // Configure output location
+    let output_path = "csv_output";
+    println!("Output will be saved to: {}/\n", output_path);
+
+    // Sampling configuration
+    let num_chains = 4; // Run 4 parallel chains for better convergence assessment
+    let num_tune = 500; // Warmup samples to tune the sampler
+    let num_draws = 500; // Post-warmup samples to keep
+
+    println!("Sampling configuration:");
+    println!("  Chains: {}", num_chains);
+    println!("  Warmup samples: {}", num_tune);
+    println!("  Sampling draws: {}", num_draws);
+
+    // Configure MCMC settings
+    // DiagGradNutsSettings provides sensible defaults for the NUTS sampler
+    let mut settings = DiagGradNutsSettings::default();
+    settings.num_chains = num_chains as _;
+    settings.num_tune = num_tune;
+    settings.num_draws = num_draws as _;
+    settings.seed = 54; // For reproducible results
+
+    // Set up CSV storage
+    // This will create one CSV file per chain in the specified directory
+    let csv_config = CsvConfig::new(output_path)
+        .with_precision(6) // 6 decimal places for floating point values
+        .store_warmup(true); // Include warmup samples with negative sample IDs
+
+    // Create the model instance
+    let model = MvnModel {
+        math: CpuMath::new(MvnLogp { model: mvn }),
+    };
+
+    // Start sampling
+    println!("\nStarting MCMC sampling...\n");
+    let start = Instant::now();
+
+    // Create sampler with 4 worker threads
+    // The sampler runs asynchronously, so we can monitor progress
+    let mut sampler = Some(Sampler::new(model, settings, csv_config, 4, None)?);
+
+    let mut num_progress_updates = 0;
+
+    // Main sampling loop with progress monitoring
+    // This demonstrates how to monitor long-running sampling jobs
+    while let Some(sampler_) = sampler.take() {
+        match sampler_.wait_timeout(Duration::from_millis(50)) {
+            // Sampling completed successfully
+            SamplerWaitResult::Trace(_) => {
+                println!("✓ Sampling completed in {:?}", start.elapsed());
+                println!("✓ Traces written to CSV format in '{}'", output_path);
+
+                // List the output files
+                if let Ok(entries) = std::fs::read_dir(output_path) {
+                    println!("\nOutput files:");
+                    for entry in entries.flatten() {
+                        if let Some(name) = entry.file_name().to_str() {
+                            if name.ends_with(".csv") {
+                                println!("  - {}", name);
+                            }
+                        }
+                    }
+                }
+
+                // Provide instructions for analysis
+                println!("\n=== Next Steps ===");
+                println!("The CSV files are compatible with CmdStan format and can be read by:");
+                println!("  - R: posterior package, bayesplot, etc.");
+                println!("  - Python: arviz.from_cmdstanpy() or pandas.read_csv()");
+                println!("  - Stan ecosystem tools");
+                println!("\nExample usage in Python:");
+                println!("  import pandas as pd");
+                println!("  import arviz as az");
+                println!("  # Read individual chains");
+                println!("  chain0 = pd.read_csv('{}/chain_0.csv')", output_path);
+                println!("  # Or use arviz to read all chains");
+                println!("  # (Note: arviz.from_cmdstanpy might need adaptation)");
+                println!("\nExample usage in R:");
+                println!("  library(posterior)");
+                println!("  draws <- read_cmdstan_csv(c(");
+                for i in 0..num_chains {
+                    let comma = if i == num_chains - 1 { "" } else { "," };
+                    println!("    '{}/chain_{}.csv'{}", output_path, i, comma);
+                }
+                println!("  ))");
+                println!("  summarise_draws(draws)");
+                break;
+            }
+
+            // Timeout - sampler is still running, show progress
+            SamplerWaitResult::Timeout(mut sampler_) => {
+                num_progress_updates += 1;
+                println!("Progress update {}:", num_progress_updates);
+
+                // Get current progress from all chains
+                let progress = sampler_.progress()?;
+                for (i, chain) in progress.iter().enumerate() {
+                    let phase = if chain.tuning { "warmup" } else { "sampling" };
+                    println!(
+                        "  Chain {}: {} samples ({} divergences), step size: {:.6} [{}]",
+                        i, chain.finished_draws, chain.divergences, chain.step_size, phase
+                    );
+                }
+                println!(); // Add blank line for readability
+
+                sampler = Some(sampler_);
+            }
+
+            // An error occurred during sampling
+            SamplerWaitResult::Err(err, _) => {
+                eprintln!("✗ Sampling failed: {}", err);
+                return Err(err);
+            }
+        }
+    }
+
+    Ok(())
+}
diff --git a/examples/hashmap_storage.rs b/examples/hashmap_storage.rs
new file mode 100644
index 0000000..f247bbd
--- /dev/null
+++ b/examples/hashmap_storage.rs
@@ -0,0 +1,288 @@
+//! HashMap storage implementation example for MCMC traces
+use std::{
+    f64,
+    time::{Duration, Instant},
+};
+
+use anyhow::Result;
+use nuts_rs::{
+    CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, HashMapConfig, LogpError, Model,
+    Sampler, SamplerWaitResult,
+};
+use nuts_storable::HasDims;
+use rand::Rng;
+use thiserror::Error;
+
+// A simple multivariate normal distribution example
+#[derive(Clone, Debug)]
+struct MultivariateNormal {
+    mean: Vec<f64>,
+    precision: Vec<Vec<f64>>,
+}
+
+impl MultivariateNormal {
+    fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
+        Self { mean, precision }
+    }
+}
+
+// Custom LogpError implementation
+#[allow(dead_code)]
+#[derive(Debug, Error)]
+enum MyLogpError {
+    #[error("Recoverable error in logp calculation: {0}")]
+    Recoverable(String),
+    #[error("Non-recoverable error in logp calculation: {0}")]
+    NonRecoverable(String),
+}
+
+impl LogpError for MyLogpError {
+    fn is_recoverable(&self) -> bool {
+        matches!(self, MyLogpError::Recoverable(_))
+    }
+}
+
+// Implementation of the model's logp function
+#[derive(Clone)]
+struct MvnLogp {
+    model: MultivariateNormal,
+}
+
+impl HasDims for MvnLogp {
+    fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
+        std::collections::HashMap::from([
+            (
+                "unconstrained_parameter".to_string(),
+                self.model.mean.len() as u64,
+            ),
+            ("dim".to_string(), self.model.mean.len() as u64),
+        ])
+    }
+}
+
+impl CpuLogpFunc for MvnLogp {
+    type LogpError = MyLogpError;
+    type FlowParameters = ();
+    type ExpandedVector = Vec<f64>;
+
+    fn dim(&self) -> usize {
+        self.model.mean.len()
+    }
+
+    fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
+        let n = x.len();
+        // Compute (x - mean)
+        let mut diff = vec![0.0; n];
+        for i in 0..n {
+            diff[i] = x[i] - self.model.mean[i];
+        }
+
+        let mut quad = 0.0;
+        // Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff
+        for i in 0..n {
+            // Compute i-th component of P * diff
+            let mut pdot = 0.0;
+            for j in 0..n {
+                let pij = self.model.precision[i][j];
+                pdot += pij * diff[j];
+                quad += diff[i] * pij * diff[j];
+            }
+            // gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is - (P * diff)_i
+            grad[i] = -pdot;
+        }
+
+        Ok(-0.5 * quad)
+    }
+
+    fn expand_vector<R: Rng + ?Sized>(
+        &mut self,
+        _rng: &mut R,
+        array: &[f64],
+    ) -> Result<Self::ExpandedVector, Self::LogpError> {
+        // Simply return the parameter values
+        Ok(array.to_vec())
+    }
+}
+
+struct MvnModel {
+    math: CpuMath<MvnLogp>,
+}
+
+/// Implementation of McmcModel for the HashMap backend
+impl Model for MvnModel {
+    type Math<'model>
+        = CpuMath<MvnLogp>
+    where
+        Self: 'model;
+
+    fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
+        Ok(self.math.clone())
+    }
+
+    /// Generate random initial positions for the chain
+    fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
+        // Initialize position randomly in [-2, 2]
+        for p in position.iter_mut() {
+            *p = rng.random_range(-2.0..2.0);
+        }
+        Ok(())
+    }
+}
+
+fn main() -> Result<()> {
+    // Create a 2D multivariate normal distribution
+    let mean = vec![0.0, 0.0];
+    let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
+    let mvn = MultivariateNormal::new(mean, precision);
+
+    // Number of chains
+    let num_chains = 4;
+
+    // Configure number of draws
+    let num_tune = 100;
+    let num_draws = 200;
+
+    // Configure MCMC settings
+    let mut settings = DiagGradNutsSettings::default();
+    settings.num_chains = num_chains as _;
+    settings.num_tune = num_tune;
+    settings.num_draws = num_draws as _;
+    settings.seed = 42;
+
+    let model = MvnModel {
+        math: CpuMath::new(MvnLogp { model: mvn }),
+    };
+
+    // Create a new sampler with 4 threads
+    let start = Instant::now();
+    let trace_config = HashMapConfig::new();
+    let mut sampler = Some(Sampler::new(model, settings, trace_config, 4, None)?);
+
+    let mut num_progress_updates = 0;
+    // Interleave progress updates with wait_timeout
+    while let Some(sampler_) = sampler.take() {
+        match sampler_.wait_timeout(Duration::from_millis(50)) {
+            SamplerWaitResult::Trace(traces) => {
+                println!("Sampling completed in {:?}", start.elapsed());
+
+                // Process the HashMap results
+                println!("\nProcessing HashMap storage results:");
+                println!("Number of chains: {}", traces.len());
+
+                for (chain_idx, chain_result) in traces.iter().enumerate() {
+                    println!("\nChain {}:", chain_idx);
+
+                    // Print stats information
+                    println!("  Sampler stats variables:");
+                    for (name, values) in &chain_result.stats {
+                        match values {
+                            nuts_rs::HashMapValue::F64(vec) => {
+                                println!("    {}: {} samples (f64)", name, vec.len());
+                                if !vec.is_empty() {
+                                    println!("      First 5: {:?}", &vec[..vec.len().min(5)]);
+                                }
+                            }
+                            nuts_rs::HashMapValue::Bool(vec) => {
+                                println!("    {}: {} samples (bool)", name, vec.len());
+                                if !vec.is_empty() {
+                                    println!("      First 5: {:?}", &vec[..vec.len().min(5)]);
+                                }
+                            }
+                            _ => println!("    {}: (other type)", name),
+                        }
+                    }
+
+                    // Print draws information
+                    println!("  Draw variables:");
+                    for (name, values) in &chain_result.draws {
+                        match values {
+                            nuts_rs::HashMapValue::F64(vec) => {
+                                println!("    {}: {} scalar draws", name, vec.len());
+                                if !vec.is_empty() {
+                                    println!("      First 5: {:?}", &vec[..vec.len().min(5)]);
+                                    if *name == "theta" && vec.len() >= 6 {
+                                        // For multidimensional parameters stored as flattened arrays
+                                        // Assume 2D parameter, so every 2 values form one draw
+                                        println!("      Parameter structure (assuming 2D):");
+                                        for i in (0..vec.len().min(10)).step_by(2) {
+                                            if i + 1 < vec.len() {
+                                                println!(
+                                                    "        Draw {}: [{:.4}, {:.4}]",
+                                                    i / 2,
+                                                    vec[i],
+                                                    vec[i + 1]
+                                                );
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                            nuts_rs::HashMapValue::F32(vec) => {
+                                println!("    {}: {} f32 draws", name, vec.len());
+                            }
+                            nuts_rs::HashMapValue::Bool(vec) => {
+                                println!("    {}: {} bool draws", name, vec.len());
+                            }
+                            nuts_rs::HashMapValue::I64(vec) => {
+                                println!("    {}: {} i64 draws", name, vec.len());
+                            }
+                            nuts_rs::HashMapValue::U64(vec) => {
+                                println!("    {}: {} u64 draws", name, vec.len());
+                            }
+                            nuts_rs::HashMapValue::String(vec) => {
+                                println!("    {}: {} string draws", name, vec.len());
+                            }
+                        }
+                    }
+
+                    // Calculate some basic statistics for theta parameter
+                    if let Some(nuts_rs::HashMapValue::F64(param_samples)) =
+                        chain_result.draws.get("theta")
+                    {
+                        if param_samples.len() >= 2 {
+                            // Assuming 2D parameter stored as flattened array: [x0, y0, x1, y1, ...]
+                            let x_values: Vec<f64> =
+                                param_samples.iter().step_by(2).cloned().collect();
+                            let y_values: Vec<f64> =
+                                param_samples.iter().skip(1).step_by(2).cloned().collect();
+
+                            if !x_values.is_empty() {
+                                let mean_x = x_values.iter().sum::<f64>() / x_values.len() as f64;
+                                println!("    theta[0] (x-component) mean: {:.4}", mean_x);
+                            }
+                            if !y_values.is_empty() {
+                                let mean_y = y_values.iter().sum::<f64>() / y_values.len() as f64;
+                                println!("    theta[1] (y-component) mean: {:.4}", mean_y);
+                            }
+                        }
+                    }
+                }
+                break;
+            }
+            SamplerWaitResult::Timeout(mut sampler_) => {
+                // Request progress update
+                if num_progress_updates < 10 {
+                    // Limit progress updates
+                    println!("Progress update {}", num_progress_updates + 1);
+                    let progress = sampler_.progress()?;
+                    for (i, chain) in progress.iter().enumerate() {
+                        println!(
+                            "Chain {}: {} samples ({} divergences), step size: {:.6}",
+                            i, chain.finished_draws, chain.divergences, chain.step_size
+                        );
+                    }
+                }
+                sampler = Some(sampler_);
+                num_progress_updates += 1;
+            }
+            SamplerWaitResult::Err(err, _) => {
+                return Err(err);
+            }
+        }
+    }
+
+    println!("\nHashMap storage example completed!");
+    println!("The results are stored in memory as HashMaps and can be easily processed in Rust.");
+
+    Ok(())
+}
diff --git a/examples/ndarray_storage.rs b/examples/ndarray_storage.rs
new file mode 100644
index 0000000..e7e789e
--- /dev/null
+++ b/examples/ndarray_storage.rs
@@ -0,0 +1,305 @@
+//! ndarray storage implementation example for MCMC traces
+use std::{
+    collections::HashMap,
+    f64,
+    time::{Duration, Instant},
+};
+
+use anyhow::Result;
+use nuts_rs::{
+    CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, NdarrayConfig,
+    Sampler, SamplerWaitResult,
+};
+use nuts_storable::HasDims;
+use rand::Rng;
+use thiserror::Error;
+
+// A simple multivariate normal distribution example
+#[derive(Clone, Debug)]
+struct MultivariateNormal {
+    mean: Vec<f64>,
+    precision: Vec<Vec<f64>>,
+}
+
+impl MultivariateNormal {
+    fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
+        Self { mean, precision }
+    }
+}
+
+// Custom LogpError implementation
+#[allow(dead_code)]
+#[derive(Debug, Error)]
+enum MyLogpError {
+    #[error("Recoverable error in logp calculation: {0}")]
+    Recoverable(String),
+    #[error("Non-recoverable error in logp calculation: {0}")]
+    NonRecoverable(String),
+}
+
+impl LogpError for MyLogpError {
+    fn is_recoverable(&self) -> bool {
+        matches!(self, MyLogpError::Recoverable(_))
+    }
+}
+
+// Implementation of the model's logp function
+#[derive(Clone)]
+struct MvnLogp {
+    model: MultivariateNormal,
+}
+
+impl HasDims for MvnLogp {
+    fn dim_sizes(&self) -> HashMap<String, u64> {
+        HashMap::from([
+            (
+                "unconstrained_parameter".to_string(),
+                self.model.mean.len() as u64,
+            ),
+            ("dim".to_string(), self.model.mean.len() as u64),
+        ])
+    }
+}
+
+impl CpuLogpFunc for MvnLogp {
+    type LogpError = MyLogpError;
+    type FlowParameters = ();
+    type ExpandedVector = Vec<f64>;
+
+    fn dim(&self) -> usize {
+        self.model.mean.len()
+    }
+
+    fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
+        let n = x.len();
+        // Compute (x - mean)
+        let mut diff = vec![0.0; n];
+        for i in 0..n {
+            diff[i] = x[i] - self.model.mean[i];
+        }
+
+        let mut quad = 0.0;
+        // Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff
+        for i in 0..n {
+            // Compute i-th component of P * diff
+            let mut pdot = 0.0;
+            for j in 0..n {
+                let pij = self.model.precision[i][j];
+                pdot += pij * diff[j];
+                quad += diff[i] * pij * diff[j];
+            }
+            // gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is - (P * diff)_i
+            grad[i] = -pdot;
+        }
+
+        Ok(-0.5 * quad)
+    }
+
+    fn expand_vector<R: Rng + ?Sized>(
+        &mut self,
+        _rng: &mut R,
+        array: &[f64],
+    ) -> Result<Self::ExpandedVector, Self::LogpError> {
+        // Simply return the parameter values
+        Ok(array.to_vec())
+    }
+}
+
+struct MvnModel {
+    math: CpuMath<MvnLogp>,
+}
+
+/// Implementation of McmcModel for the ndarray backend
+impl Model for MvnModel {
+    type Math<'model>
+        = CpuMath<MvnLogp>
+    where
+        Self: 'model;
+
+    fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
+        Ok(self.math.clone())
+    }
+
+    /// Generate random initial positions for the chain
+    fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
+        // Initialize position randomly in [-2, 2]
+        for p in position.iter_mut() {
+            *p = rng.random_range(-2.0..2.0);
+        }
+        Ok(())
+    }
+}
+
+fn main() -> Result<()> {
+    // Create a 3D multivariate normal distribution for more interesting results
+    let mean = vec![0.0, 1.0, -0.5];
+    let precision = vec![
+        vec![2.0, 0.3, 0.1],
+        vec![0.3, 1.5, -0.2],
+        vec![0.1, -0.2, 1.0],
+    ];
+    let mvn = MultivariateNormal::new(mean, precision);
+
+    // Number of chains
+    let num_chains = 3;
+
+    // Configure number of draws
+    let num_tune = 50;
+    let num_draws = 100;
+
+    // Configure MCMC settings
+    let mut settings = DiagGradNutsSettings::default();
+    settings.num_chains = num_chains as _;
+    settings.num_tune = num_tune;
+    settings.num_draws = num_draws as _;
+    settings.seed = 123;
+
+    let model = MvnModel {
+        math: CpuMath::new(MvnLogp { model: mvn }),
+    };
+
+    // Create a new sampler with 3 threads
+    let start = Instant::now();
+    let trace_config = NdarrayConfig::new();
+    let mut sampler = Some(Sampler::new(
+        model,
+        settings,
+        trace_config,
+        num_chains,
+        None,
+    )?);
+
+    let mut num_progress_updates = 0;
+    // Interleave progress updates with wait_timeout
+    while let Some(sampler_) = sampler.take() {
+        match sampler_.wait_timeout(Duration::from_millis(100)) {
+            SamplerWaitResult::Trace(result) => {
+                println!("Sampling completed in {:?}", start.elapsed());
+
+                // Process the ndarray results
+                println!("\nProcessing ndarray storage results:");
+
+                // Print stats information
+                println!("Sampler stats variables:");
+                for (name, values) in &result.stats {
+                    match values {
+                        nuts_rs::NdarrayValue::F64(arr) => {
+                            println!("  {}: shape {:?} (f64)", name, arr.shape());
+                            if arr.len() > 0 {
+                                // Print some sample values from the first chain
+                                if arr.ndim() >= 2 {
+                                    let chain_0_view = arr.slice(ndarray::s![0, ..5]);
+                                    println!("    Chain 0, first 5 samples: {:?}", chain_0_view);
+                                }
+                            }
+                        }
+                        nuts_rs::NdarrayValue::Bool(arr) => {
+                            println!("  {}: shape {:?} (bool)", name, arr.shape());
+                            if arr.len() > 0 && arr.ndim() >= 2 {
+                                let chain_0_view = arr.slice(ndarray::s![0, ..5]);
+                                println!("    Chain 0, first 5 samples: {:?}", chain_0_view);
+                            }
+                        }
+                        _ => println!("  {}: (other type)", name),
+                    }
+                }
+
+                // Print draws information
+                println!("\nDraw variables:");
+                for (name, values) in &result.draws {
+                    match values {
+                        nuts_rs::NdarrayValue::F64(arr) => {
+                            println!("  {}: shape {:?} (f64)", name, arr.shape());
+                            if arr.len() > 0 {
+                                // Print statistics for each parameter dimension
+                                if arr.ndim() == 3 {
+                                    // Shape is (chains, draws, parameters)
+                                    let num_params = arr.shape()[2];
+                                    for param_idx in 0..num_params {
+                                        let param_slice = arr.slice(ndarray::s![.., .., param_idx]);
+                                        let mean = param_slice.mean().unwrap_or(f64::NAN);
+                                        let std = param_slice.std(0.0);
+                                        println!(
+                                            "    Parameter {}: mean={:.4}, std={:.4}",
+                                            param_idx, mean, std
+                                        );
+                                    }
+
+                                    // Print some sample values from each chain
+                                    println!("    Sample values from each chain (first 3 draws):");
+                                    for chain_idx in 0..(arr.shape()[0].min(3)) {
+                                        let chain_samples =
+                                            arr.slice(ndarray::s![chain_idx, ..3, ..]);
+                                        println!("      Chain {}: {:?}", chain_idx, chain_samples);
+                                    }
+                                } else {
+                                    // Just print overall mean if not the expected 3D shape
+                                    let mean = arr.mean().unwrap_or(f64::NAN);
+                                    println!("    Overall mean: {:.4}", mean);
+                                }
+                            }
+                        }
+                        _ => println!("  {}: (other type)", name),
+                    }
+                }
+
+                // Demonstrate accessing individual samples
+                if let Some(nuts_rs::NdarrayValue::F64(theta_arr)) = result.draws.get("theta") {
+                    if theta_arr.ndim() == 3 && theta_arr.shape()[0] > 0 && theta_arr.shape()[1] > 0
+                    {
+                        println!("\nExample: Accessing specific samples:");
+
+                        // Get the 10th sample from chain 0
+                        if theta_arr.shape()[1] > 9 {
+                            let sample = theta_arr.slice(ndarray::s![0, 9, ..]);
+                            println!("  Chain 0, sample 10: {:?}", sample);
+                        }
+
+                        // Get all samples for parameter 0 from chain 1
+                        if theta_arr.shape()[0] > 1 {
+                            let param_0_chain_1 = theta_arr.slice(ndarray::s![1, .., 0]);
+                            println!(
+                                "  Chain 1, parameter 0, all samples: shape {:?}",
+                                param_0_chain_1.shape()
+                            );
+                            println!(
+                                "    First 5 values: {:?}",
+                                param_0_chain_1.slice(ndarray::s![..5])
+                            );
+                        }
+                    }
+                }
+                break;
+            }
+            SamplerWaitResult::Timeout(mut sampler_) => {
+                // Request progress update
+                if num_progress_updates < 5 {
+                    // Limit progress updates
+                    println!("Progress update {}", num_progress_updates + 1);
+                    let progress = sampler_.progress()?;
+                    for (i, chain) in progress.iter().enumerate() {
+                        println!(
+                            "Chain {}: {} samples ({} divergences), step size: {:.6}",
+                            i, chain.finished_draws, chain.divergences, chain.step_size
+                        );
+                    }
+                }
+                sampler = Some(sampler_);
+                num_progress_updates += 1;
+            }
+            SamplerWaitResult::Err(err, _) => {
+                return Err(err);
+            }
+        }
+    }
+
+    println!("\nndarray storage example completed!");
+    println!(
+        "The results are stored as efficient ndarray structures with shape (chains, draws, parameters)."
+    );
+    println!(
+        "This format is ideal for numerical analysis and can be easily converted to other formats."
+    );
+
+    Ok(())
+}
diff --git a/examples/zarr_async_trace.rs b/examples/zarr_async_trace.rs
new file mode 100644
index 0000000..7e02402
--- /dev/null
+++ b/examples/zarr_async_trace.rs
@@ -0,0 +1,331 @@
+//! Zarr async backend example for MCMC trace storage
+//!
+//! This example demonstrates how to use the nuts-rs library with async Zarr storage
+//! for running MCMC sampling on a multivariate normal distribution. It shows:
+//!
+//! - Setting up a custom probability model
+//! - Configuring async Zarr storage for results
+//! - Running multiple parallel chains with async I/O
+//! - Monitoring progress during sampling
+//! - Saving results in ArviZ-compatible format
+
+use std::{
+    collections::HashMap,
+    f64,
+    sync::Arc,
+    time::{Duration, Instant},
+};
+
+use anyhow::Result;
+use nuts_rs::{
+    CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, Sampler,
+    SamplerWaitResult, Storable, ZarrAsyncConfig,
+};
+use nuts_storable::{HasDims, Value};
+use rand::Rng;
+use thiserror::Error;
+use zarrs_object_store::AsyncObjectStore;
+
+/// A multivariate normal distribution model
+///
+/// This represents a probability distribution with mean μ and precision matrix P,
+/// where the log probability is: logp(x) = -0.5 * (x - μ)^T * P * (x - μ)
+#[derive(Clone, Debug)]
+struct MultivariateNormal {
+    mean: Vec<f64>,
+    precision: Vec<Vec<f64>>, // Inverse of covariance matrix
+}
+
+impl MultivariateNormal {
+    fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
+        Self { mean, precision }
+    }
+}
+
+/// Custom error type for log probability calculations
+///
+/// MCMC samplers need to distinguish between recoverable errors (like numerical
+/// issues that can be handled by rejecting the proposal) and non-recoverable
+/// errors (like programming bugs that should stop sampling).
+#[allow(dead_code)]
+#[derive(Debug, Error)]
+enum MyLogpError {
+    #[error("Recoverable error in logp calculation: {0}")]
+    Recoverable(String),
+    #[error("Non-recoverable error in logp calculation: {0}")]
+    NonRecoverable(String),
+}
+
+impl LogpError for MyLogpError {
+    fn is_recoverable(&self) -> bool {
+        matches!(self, MyLogpError::Recoverable(_))
+    }
+}
+
+/// Implementation of the log probability function for multivariate normal
+///
+/// This struct contains the model parameters and implements the mathematical
+/// operations needed for MCMC sampling: computing log probability and gradients.
+#[derive(Clone)]
+struct MvnLogp {
+    model: MultivariateNormal,
+    buffer: Vec<f64>, // Temporary buffer for computations
+}
+
+impl HasDims for MvnLogp {
+    /// Define dimension names and sizes for storage
+    ///
+    /// This tells the storage system what array dimensions to expect.
+    /// These dimensions will be used to structure the output data.
+    fn dim_sizes(&self) -> HashMap<String, u64> {
+        HashMap::from([
+            // Dimension for the actual parameter vector x
+            ("x".to_string(), self.model.mean.len() as u64),
+        ])
+    }
+
+    fn coords(&self) -> HashMap<String, Value> {
+        HashMap::from([(
+            "x".to_string(),
+            Value::Strings(vec!["x1".to_string(), "x2".to_string()]),
+        )])
+    }
+}
+
+/// Additional quantities computed from each sample
+///
+/// The `Storable` derive macro automatically generates code to store this
+/// struct in the trace. The `dims` attribute specifies which dimension
+/// each field should use.
+#[derive(Storable)]
+struct ExpandedDraw {
+    /// Store the parameter values with dimension "x"
+    #[storable(dims("x"))]
+    prec: Vec<f64>,
+    /// A scalar derived quantity (difference between first two parameters)
+    diff: f64,
+}
+
+impl CpuLogpFunc for MvnLogp {
+    type LogpError = MyLogpError;
+    type FlowParameters = (); // No parameter transformations needed
+    type ExpandedVector = ExpandedDraw;
+
+    /// Return the dimensionality of the parameter space
+    fn dim(&self) -> usize {
+        self.model.mean.len()
+    }
+
+    /// Compute log probability and gradient
+    ///
+    /// This is the core mathematical function that MCMC uses to explore
+    /// the parameter space. It computes both the log probability density
+    /// and its gradient for efficient sampling with Hamiltonian Monte Carlo.
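+    ///
+    /// Concretely, with mean `mu` and precision matrix `P`:
+    ///     logp(x) = -0.5 * (x - mu)^T P (x - mu)
+    ///     grad(x) = -P (x - mu)
+    /// Both quantities fall out of a single O(n^2) pass over `P` below.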
+    fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
+        let n = x.len();
+
+        // Compute (x - mean)
+        let diff = &mut self.buffer;
+        for i in 0..n {
+            diff[i] = x[i] - self.model.mean[i];
+        }
+
+        let mut quad = 0.0;
+
+        // Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff
+        for i in 0..n {
+            // Compute i-th component of P * diff
+            let mut pdot = 0.0;
+            for j in 0..n {
+                let pij = self.model.precision[i][j];
+                pdot += pij * diff[j];
+                quad += diff[i] * pij * diff[j];
+            }
+            // Gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is -(P * diff)_i
+            grad[i] = -pdot;
+        }
+
+        Ok(-0.5 * quad)
+    }
+
+    /// Compute additional quantities from each sample
+    ///
+    /// This function is called for each accepted sample to compute derived
+    /// quantities that should be stored in the trace. These might be
+    /// transformed parameters, predictions, or other quantities of interest.
+    fn expand_vector<R: Rng + ?Sized>(
+        &mut self,
+        _rng: &mut R,
+        array: &[f64],
+    ) -> Result<Self::ExpandedVector, Self::LogpError> {
+        // Store the raw parameter values and compute a simple derived quantity
+        Ok(ExpandedDraw {
+            prec: array.to_vec(),
+            diff: array[1] - array[0], // Example: difference between first two parameters
+        })
+    }
+
+    fn vector_coord(&self) -> Option<Value> {
+        Some(Value::Strings(vec!["x1".to_string(), "x2".to_string()]))
+    }
+}
+
+/// The complete MCMC model
+///
+/// This struct implements the Model trait, which is the main interface
+/// that samplers use. It provides access to the mathematical operations
+/// and handles initialization of the sampling chains.
+struct MvnModel {
+    math: CpuMath<MvnLogp>,
+}
+
+impl Model for MvnModel {
+    type Math<'model>
+        = CpuMath<MvnLogp>
+    where
+        Self: 'model;
+
+    fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
+        Ok(self.math.clone())
+    }
+
+    /// Generate random initial positions for the chain
+    ///
+    /// Good initialization is important for MCMC efficiency. The starting
+    /// points should be in a reasonable region of the parameter space
+    /// where the log probability is finite.
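+    ///
+    /// For this 2D example the mode sits at the origin, so uniform draws from
+    /// `[-2, 2]` per coordinate (as implemented below) land well inside the
+    /// bulk of the posterior.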
+    fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
+        // Initialize each parameter randomly in the range [-2, 2]
+        // For this simple example, this should put us in a reasonable
+        // region around the mode of the distribution
+        for p in position.iter_mut() {
+            *p = rng.random_range(-2.0..2.0);
+        }
+        Ok(())
+    }
+}
+
+fn main() -> Result<()> {
+    println!("=== Multivariate Normal MCMC with Async Zarr Storage ===\n");
+
+    // Create a 2D multivariate normal distribution
+    // This creates a distribution with mean [0, 0] and precision matrix
+    // [[1.0, 0.5], [0.5, 1.0]], which corresponds to some correlation
+    // between the two parameters
+    let mean = vec![0.0, 0.0];
+    let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
+    let mvn = MultivariateNormal::new(mean, precision);
+
+    println!("Model: 2D Multivariate Normal");
+    println!("Mean: {:?}", mvn.mean);
+    println!("Precision matrix: {:?}\n", mvn.precision);
+
+    // Configure output location
+    let output_path = "mcmc_output/async_trace.zarr";
+    println!("Output will be saved to: {}\n", output_path);
+
+    // Sampling configuration
+    let num_chains = 4; // Run 4 parallel chains for better convergence assessment
+    let num_tune = 500; // Warmup samples to tune the sampler
+    let num_draws = 500; // Post-warmup samples to keep
+
+    println!("Sampling configuration:");
+    println!("  Chains: {}", num_chains);
+    println!("  Warmup samples: {}", num_tune);
+    println!("  Sampling draws: {}", num_draws);
+
+    // Configure MCMC settings
+    // DiagGradNutsSettings provides sensible defaults for the NUTS sampler
+    let mut settings = DiagGradNutsSettings::default();
+    settings.num_chains = num_chains as _;
+    settings.num_tune = num_tune;
+    settings.num_draws = num_draws as _;
+    settings.seed = 54; // For reproducible results
+
+    // Make sure the output directory exists before canonicalizing it; both
+    // canonicalize and LocalFileSystem::new_with_prefix require an existing path.
+    std::fs::create_dir_all(output_path)?;
+    let path = std::path::Path::new(output_path).canonicalize()?;
+    let object_store = object_store::local::LocalFileSystem::new_with_prefix(path)?;
+    let store = Arc::new(AsyncObjectStore::new(object_store));
+
+    // Create the model instance
+    let model = MvnModel {
+        math: CpuMath::new(MvnLogp {
+            model: mvn,
+            buffer: vec![0.0; 2],
+        }),
+    };
+
+    // Start sampling
+    println!("\nStarting MCMC sampling with async Zarr backend...\n");
+    let start = Instant::now();
+
+    // Configure async Zarr storage with default settings
+    // This uses async I/O operations to avoid blocking during writes
+    let rt = tokio::runtime::Builder::new_multi_thread()
+        .worker_threads(4)
+        .enable_all()
+        .build()
+        .unwrap();
+    let handle = rt.handle().clone();
+    let zarr_async_config = ZarrAsyncConfig::new(handle, store.clone());
+
+    // Create sampler with 4 worker threads
+    // The sampler runs asynchronously, so we can monitor progress
+    let mut sampler = Some(Sampler::new(model, settings, zarr_async_config, 4, None)?);
+
+    let mut num_progress_updates = 0;
+
+    // Main sampling loop with progress monitoring
+    // This demonstrates how to monitor long-running sampling jobs
+    while let Some(sampler_) = sampler.take() {
+        match sampler_.wait_timeout(Duration::from_millis(50)) {
+            // Sampling completed successfully
+            SamplerWaitResult::Trace(_) => {
+                println!("✓ Async sampling completed in {:?}", start.elapsed());
+                println!("✓ Traces written to Zarr format at '{}'", output_path);
+
+                // Provide instructions for analysis
+                println!("\n=== Next Steps ===");
+                println!("To analyze results in Python with ArviZ:");
+                println!("  import arviz as az");
+                println!("  data = az.from_zarr('{}')", output_path);
+                println!("  az.plot_trace(data)");
+                println!("  az.summary(data)");
+                println!("\nThe async Zarr format contains:");
+                println!("  - posterior/: Main sampling results");
+                println!("  - sample_stats/: Sampler diagnostics");
+                println!("  - warmup_*: Warmup phase results");
+                println!("\nNote: The async backend uses tokio tasks for I/O operations,");
+                println!("      which can improve performance by avoiding blocking writes.");
+                break;
+            }
+
+            // Timeout - sampler is still running, show progress
+            SamplerWaitResult::Timeout(mut sampler_) => {
+                num_progress_updates += 1;
+                println!("Progress update {} (async I/O):", num_progress_updates);
+
+                // Get current progress from all chains
+                let progress = sampler_.progress()?;
+                for (i, chain) in progress.iter().enumerate() {
+                    let phase = if chain.tuning { "warmup" } else { "sampling" };
+                    println!(
+                        "  Chain {}: {} samples ({} divergences), step size: {:.6} [{}]",
+                        i, chain.finished_draws, chain.divergences, chain.step_size, phase
+                    );
+                }
+                println!("  (Zarr writes are happening asynchronously in the background)");
+                println!(); // Add blank line for readability
+
+                sampler = Some(sampler_);
+            }
+
+            // An error occurred during sampling
+            SamplerWaitResult::Err(err, _) => {
+                eprintln!("✗ Async sampling failed: {}", err);
+                return Err(err);
+            }
+        }
+    }
+
+    Ok(())
+}
diff --git a/examples/zarr_trace.rs b/examples/zarr_trace.rs
new file mode 100644
index 0000000..c569dd6
--- /dev/null
+++ b/examples/zarr_trace.rs
@@ -0,0 +1,322 @@
+//! Zarr backend example for MCMC trace storage
+//!
+//! This example demonstrates how to use the nuts-rs library with Zarr storage
+//! for running MCMC sampling on a multivariate normal distribution. It shows:
+//!
+//! - Setting up a custom probability model
+//! - Configuring Zarr storage for results
+//! - Running multiple parallel chains
+//! - Monitoring progress during sampling
+//! - Saving results in ArviZ-compatible format
+
+use std::{
+    collections::HashMap,
+    f64,
+    sync::Arc,
+    time::{Duration, Instant},
+};
+
+use anyhow::Result;
+use nuts_rs::{
+    CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, Sampler,
+    SamplerWaitResult, Storable, ZarrConfig,
+};
+use nuts_storable::{HasDims, Value};
+use rand::Rng;
+use thiserror::Error;
+use zarrs::filesystem::FilesystemStore;
+
+/// A multivariate normal distribution model
+///
+/// This represents a probability distribution with mean μ and precision matrix P,
+/// where the log probability is: logp(x) = -0.5 * (x - μ)^T * P * (x - μ)
+#[derive(Clone, Debug)]
+struct MultivariateNormal {
+    mean: Vec<f64>,
+    precision: Vec<Vec<f64>>, // Inverse of covariance matrix
+}
+
+impl MultivariateNormal {
+    fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
+        Self { mean, precision }
+    }
+}
+
+/// Custom error type for log probability calculations
+///
+/// MCMC samplers need to distinguish between recoverable errors (like numerical
+/// issues that can be handled by rejecting the proposal) and non-recoverable
+/// errors (like programming bugs that should stop sampling).
+#[allow(dead_code)]
+#[derive(Debug, Error)]
+enum MyLogpError {
+    #[error("Recoverable error in logp calculation: {0}")]
+    Recoverable(String),
+    #[error("Non-recoverable error in logp calculation: {0}")]
+    NonRecoverable(String),
+}
+
+impl LogpError for MyLogpError {
+    fn is_recoverable(&self) -> bool {
+        matches!(self, MyLogpError::Recoverable(_))
+    }
+}
+
+/// Implementation of the log probability function for multivariate normal
+///
+/// This struct contains the model parameters and implements the mathematical
+/// operations needed for MCMC sampling: computing log probability and gradients.
+#[derive(Clone)]
+struct MvnLogp {
+    model: MultivariateNormal,
+    buffer: Vec<f64>, // Temporary buffer for computations
+}
+
+impl HasDims for MvnLogp {
+    /// Define dimension names and sizes for storage
+    ///
+    /// This tells the storage system what array dimensions to expect.
+    /// These dimensions will be used to structure the output data.
+    fn dim_sizes(&self) -> HashMap<String, u64> {
+        HashMap::from([
+            // Dimension for the actual parameter vector x
+            ("x".to_string(), self.model.mean.len() as u64),
+        ])
+    }
+
+    fn coords(&self) -> HashMap<String, Value> {
+        HashMap::from([(
+            "x".to_string(),
+            Value::Strings(vec!["x1".to_string(), "x2".to_string()]),
+        )])
+    }
+}
+
+/// Additional quantities computed from each sample
+///
+/// The `Storable` derive macro automatically generates code to store this
+/// struct in the trace. The `dims` attribute specifies which dimension
+/// each field should use.
+#[derive(Storable)]
+struct ExpandedDraw {
+    /// Store the parameter values with dimension "x"
+    #[storable(dims("x"))]
+    prec: Vec<f64>,
+    /// A scalar derived quantity (difference between first two parameters)
+    diff: f64,
+}
+
+impl CpuLogpFunc for MvnLogp {
+    type LogpError = MyLogpError;
+    type FlowParameters = (); // No parameter transformations needed
+    type ExpandedVector = ExpandedDraw;
+
+    /// Return the dimensionality of the parameter space
+    fn dim(&self) -> usize {
+        self.model.mean.len()
+    }
+
+    /// Compute log probability and gradient
+    ///
+    /// This is the core mathematical function that MCMC uses to explore
+    /// the parameter space. It computes both the log probability density
+    /// and its gradient for efficient sampling with Hamiltonian Monte Carlo.
+    fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
+        let n = x.len();
+
+        // Compute (x - mean)
+        let diff = &mut self.buffer;
+        for i in 0..n {
+            diff[i] = x[i] - self.model.mean[i];
+        }
+
+        let mut quad = 0.0;
+
+        // Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff
+        for i in 0..n {
+            // Compute i-th component of P * diff
+            let mut pdot = 0.0;
+            for j in 0..n {
+                let pij = self.model.precision[i][j];
+                pdot += pij * diff[j];
+                quad += diff[i] * pij * diff[j];
+            }
+            // Gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is -(P * diff)_i
+            grad[i] = -pdot;
+        }
+
+        Ok(-0.5 * quad)
+    }
+
+    /// Compute additional quantities from each sample
+    ///
+    /// This function is called for each accepted sample to compute derived
+    /// quantities that should be stored in the trace. These might be
+    /// transformed parameters, predictions, or other quantities of interest.
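+    ///
+    /// Here the expanded draw keeps the raw position under the `"x"` dimension
+    /// and adds one derived scalar, `diff = x[1] - x[0]`.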
+ fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> Result { + // Store the raw parameter values and compute a simple derived quantity + Ok(ExpandedDraw { + prec: array.to_vec(), + diff: array[1] - array[0], // Example: difference between first two parameters + }) + } + + fn vector_coord(&self) -> Option { + Some(Value::Strings(vec!["x1".to_string(), "x2".to_string()])) + } +} + +/// The complete MCMC model +/// +/// This struct implements the Model trait, which is the main interface +/// that samplers use. It provides access to the mathematical operations +/// and handles initialization of the sampling chains. +struct MvnModel { + math: CpuMath, +} + +impl Model for MvnModel { + type Math<'model> + = CpuMath + where + Self: 'model; + + fn math(&self, _rng: &mut R) -> Result> { + Ok(self.math.clone()) + } + + /// Generate random initial positions for the chain + /// + /// Good initialization is important for MCMC efficiency. The starting + /// points should be in a reasonable region of the parameter space + /// where the log probability is finite. + fn init_position(&self, rng: &mut R, position: &mut [f64]) -> Result<()> { + // Initialize each parameter randomly in the range [-2, 2] + // For this simple example, this should put us in a reasonable + // region around the mode of the distribution + for p in position.iter_mut() { + *p = rng.random_range(-2.0..2.0); + } + Ok(()) + } +} + +fn main() -> Result<()> { + println!("=== Multivariate Normal MCMC with Zarr Storage ===\n"); + + // Create a 2D multivariate normal distribution + // This creates a distribution with mean [0, 0] and precision matrix + // [[1.0, 0.5], [0.5, 1.0]], which corresponds to some correlation + // between the two parameters + let mean = vec![0.0, 0.0]; + let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]]; + let mvn = MultivariateNormal::new(mean, precision); + + println!("Model: 2D Multivariate Normal"); + println!("Mean: {:?}", mvn.mean); + println!("Precision matrix: {:?}\n", mvn.precision); + + // Configure output location + let output_path = "mcmc_output/trace.zarr"; + println!("Output will be saved to: {}\n", output_path); + + // Sampling configuration + let num_chains = 4; // Run 4 parallel chains for better convergence assessment + let num_tune = 500; // Warmup samples to tune the sampler + let num_draws = 500; // Post-warmup samples to keep + + println!("Sampling configuration:"); + println!(" Chains: {}", num_chains); + println!(" Warmup samples: {}", num_tune); + println!(" Sampling draws: {}", num_draws); + + // Configure MCMC settings + // DiagGradNutsSettings provides sensible defaults for the NUTS sampler + let mut settings = DiagGradNutsSettings::default(); + settings.num_chains = num_chains as _; + settings.num_tune = num_tune; + settings.num_draws = num_draws as _; + settings.seed = 54; // For reproducible results + + // Set up Zarr storage + // FilesystemStore writes to a directory on disk in Zarr format + let store: zarrs::storage::ReadableWritableListableStorage = + Arc::new(FilesystemStore::new(&output_path)?); + + // Create the model instance + let model = MvnModel { + math: CpuMath::new(MvnLogp { + model: mvn, + buffer: vec![0.0; 2], + }), + }; + + // Start sampling + println!("\nStarting MCMC sampling...\n"); + let start = Instant::now(); + + // Configure Zarr storage with default settings + let zarr_config = ZarrConfig::new(store.clone()); + + // Create sampler with 4 worker threads + // The sampler runs asynchronously, so we can monitor progress + let mut sampler = 
Some(Sampler::new(model, settings, zarr_config, 4, None)?); + + let mut num_progress_updates = 0; + + // Main sampling loop with progress monitoring + // This demonstrates how to monitor long-running sampling jobs + while let Some(sampler_) = sampler.take() { + match sampler_.wait_timeout(Duration::from_millis(50)) { + // Sampling completed successfully + SamplerWaitResult::Trace(_) => { + println!("✓ Sampling completed in {:?}", start.elapsed()); + println!("✓ Traces written to Zarr format at '{}'", output_path); + + // Provide instructions for analysis + println!("\n=== Next Steps ==="); + println!("To analyze results in Python with ArviZ:"); + println!(" import arviz as az"); + println!(" data = az.from_zarr('{}')", output_path); + println!(" az.plot_trace(data)"); + println!(" az.summary(data)"); + println!("\nThe Zarr format contains:"); + println!(" - posterior/: Main sampling results"); + println!(" - sample_stats/: Sampler diagnostics"); + println!(" - warmup_*: Warmup phase results"); + break; + } + + // Timeout - sampler is still running, show progress + SamplerWaitResult::Timeout(mut sampler_) => { + num_progress_updates += 1; + println!("Progress update {}:", num_progress_updates); + + // Get current progress from all chains + let progress = sampler_.progress()?; + for (i, chain) in progress.iter().enumerate() { + let phase = if chain.tuning { "warmup" } else { "sampling" }; + println!( + " Chain {}: {} samples ({} divergences), step size: {:.6} [{}]", + i, chain.finished_draws, chain.divergences, chain.step_size, phase + ); + } + println!(); // Add blank line for readability + + sampler = Some(sampler_); + } + + // An error occurred during sampling + SamplerWaitResult::Err(err, _) => { + eprintln!("✗ Sampling failed: {}", err); + return Err(err); + } + } + } + + Ok(()) +} diff --git a/nuts-derive/Cargo.toml b/nuts-derive/Cargo.toml new file mode 100644 index 0000000..bdd7f3a --- /dev/null +++ b/nuts-derive/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "nuts-derive" +version = "0.1.0" +edition = "2024" + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "1.0", features = ["full"] } +nuts-storable = { path = "../nuts-storable" } + +[lib] +proc-macro = true diff --git a/nuts-derive/src/lib.rs b/nuts-derive/src/lib.rs new file mode 100644 index 0000000..5cd34e5 --- /dev/null +++ b/nuts-derive/src/lib.rs @@ -0,0 +1,559 @@ +extern crate proc_macro; + +use proc_macro::TokenStream; +use quote::{ToTokens, quote}; +use syn::{ + AngleBracketedGenericArguments, Data, DeriveInput, Fields, GenericParam, Ident, Lit, LitStr, + PathArguments, Token, Type, TypePath, + parse::{Parse, ParseStream}, + parse_macro_input, + punctuated::Punctuated, +}; + +// Helper struct to parse `#[storable(dims(...))]` or #[storable(flattened)] +enum StorableAttr { + Item(Vec), + Flattened(), + Ignore(), +} + +impl Parse for StorableAttr { + fn parse(input: ParseStream) -> syn::Result { + let metas = Punctuated::::parse_terminated(input)?; + + for meta in metas { + match meta { + syn::Meta::List(list) => { + if list.path.is_ident("dims") { + return Ok(StorableAttr::Item( + list.nested + .into_iter() + .map(|e| match e { + syn::NestedMeta::Lit(Lit::Str(s)) => Ok(s), + _ => Err(syn::Error::new_spanned(e, "Expected string literal")), + }) + .collect::, _>>()?, + )); + } + } + syn::Meta::Path(path) => { + if path.is_ident("flatten") { + return Ok(StorableAttr::Flattened()); + } + if path.is_ident("ignore") { + return Ok(StorableAttr::Ignore()); + } + } + _ => { + return 
Err(syn::Error::new_spanned( + meta, + "Unsupported storable attribute. Expected `dims(...)` or `flatten`", + )); + } + } + } + + Ok(StorableAttr::Item(vec![])) + } +} + +struct StorableBasicField { + name: Ident, + item_type: proc_macro2::TokenStream, + is_vec: bool, + is_option: bool, + dims: Vec, +} + +struct StorableInnerField { + name: Ident, + item_type: proc_macro2::TokenStream, + is_option: bool, +} + +enum StorableField { + Basic(StorableBasicField), + Inner(StorableInnerField), + Generic(StorableInnerField), +} + +// Check if a type is a generic type parameter +fn is_generic_param(ty: &Type, generics: &syn::Generics) -> bool { + if let Type::Path(type_path) = ty { + if type_path.path.segments.len() == 1 { + let type_name = &type_path.path.segments.first().unwrap().ident; + return generics.params.iter().any(|param| { + if let GenericParam::Type(type_param) = param { + &type_param.ident == type_name + } else { + false + } + }); + } + } + false +} + +// Check if a type implements Storable trait based on bounds +fn has_storable_bound(ty: &Ident, generics: &syn::Generics) -> bool { + for param in &generics.params { + if let GenericParam::Type(type_param) = param { + if &type_param.ident == ty { + for bound in &type_param.bounds { + if let syn::TypeParamBound::Trait(trait_bound) = bound { + let path = &trait_bound.path; + if path.segments.len() == 1 + && path.segments.first().unwrap().ident == "Storable" + { + return true; + } + } + } + } + } + } + false +} + +#[proc_macro_derive(Storable, attributes(storable))] +pub fn storable_derive(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let name = &ast.ident; + let generics = &ast.generics; + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let impl_generics = if generics.params.is_empty() { + quote! { } + } else { + quote! { #impl_generics } + }; + + let fields = if let Data::Struct(s) = ast.data { + if let Fields::Named(fields) = s.fields { + fields.named + } else { + panic!("Storable can only be derived for structs with named fields"); + } + } else { + panic!("Storable can only be derived on structs"); + }; + + let mut storable_fields = Vec::new(); + for field in fields { + let field_name = field.ident.clone().unwrap(); + let ty = &field.ty; + let ty_str = quote!(#ty).to_string(); + + let attr = field + .attrs + .iter() + .find(|a| a.path.is_ident("storable")) + .map(|a| a.parse_args::().unwrap()); + + if let Some(StorableAttr::Ignore()) = attr { + continue; // Skip this field + } + + let attr = attr.unwrap_or(StorableAttr::Item(vec![])); + + if let StorableAttr::Flattened() = attr { + let path = if let Type::Path(TypePath { path: p, qself: _ }) = ty { + p + } else { + panic!( + "Unsupported field type with flattened attribute: {}", + ty_str + ); + }; + let item = if path.segments.first().unwrap().ident.to_string() == "Option" { + if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { + args, .. 
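+                // `#[storable(flatten)]` on an `Option<T>` flattens the inner `T`,
+                // so pull `T` out of the `Option`'s generic argument list: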
+ }) = &path.segments.first().unwrap().arguments + { + if let Some(arg) = args.first() { + let inner_type = quote!(#arg); + StorableField::Inner(StorableInnerField { + name: field_name.clone(), + item_type: inner_type, + is_option: true, + }) + } else { + panic!("Invalid Option type for flattened field"); + } + } else { + panic!("Invalid Option type for flattened field"); + } + } else { + StorableField::Inner(StorableInnerField { + name: field_name.clone(), + item_type: path.into_token_stream(), + is_option: false, + }) + }; + storable_fields.push(item); + continue; + } + + let dims = if let StorableAttr::Item(dims) = attr { + dims + } else { + vec![] + }; + + // Check if the field is a generic type parameter + if let Type::Path(type_path) = ty { + if type_path.path.segments.len() == 1 { + let type_name = &type_path.path.segments.first().unwrap().ident; + + // Check if this is a generic type parameter with Storable bound + if is_generic_param(ty, generics) && has_storable_bound(type_name, generics) { + storable_fields.push(StorableField::Generic(StorableInnerField { + name: field_name, + item_type: quote!(#type_name), + is_option: false, + })); + continue; + } + + // Check if this is Option where T is a generic type parameter + if type_name == "Option" { + if let PathArguments::AngleBracketed(args) = + &type_path.path.segments.first().unwrap().arguments + { + if let Some(arg) = args.args.first() { + if let syn::GenericArgument::Type(inner_ty) = arg { + if let Type::Path(inner_path) = inner_ty { + if inner_path.path.segments.len() == 1 { + let inner_name = + &inner_path.path.segments.first().unwrap().ident; + if is_generic_param(inner_ty, generics) + && has_storable_bound(inner_name, generics) + { + storable_fields.push(StorableField::Generic( + StorableInnerField { + name: field_name, + item_type: quote!(#inner_name), + is_option: true, + }, + )); + continue; + } + } + } + } + } + } + } + } + } + + let item = match ty_str.as_str() { + "u64" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::U64 }, + is_vec: false, + is_option: false, + dims, + }), + "i64" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::I64 }, + is_vec: false, + is_option: false, + dims, + }), + "f64" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::F64 }, + is_vec: false, + is_option: false, + dims, + }), + "f32" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::F32 }, + is_vec: false, + is_option: false, + dims, + }), + "bool" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::Bool }, + is_vec: false, + is_option: false, + dims, + }), + "Option < u64 >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::U64 }, + is_vec: false, + is_option: true, + dims, + }), + "Option < i64 >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::I64 }, + is_vec: false, + is_option: true, + dims, + }), + "Option < f64 >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! 
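+            // Note: `quote!(#ty).to_string()` separates tokens with spaces, which
+            // is why these patterns read `Option < f64 >` rather than `Option<f64>`.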
{ nuts_storable::ItemType::F64 }, + is_vec: false, + is_option: true, + dims, + }), + "Option < f32 >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::F32 }, + is_vec: false, + is_option: true, + dims, + }), + "Option < bool >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::Bool }, + is_vec: false, + is_option: true, + dims, + }), + "Vec < u64 >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::U64 }, + is_vec: true, + is_option: false, + dims, + }), + "Vec < i64 >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::I64 }, + is_vec: true, + is_option: false, + dims, + }), + "Vec < f64 >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::F64 }, + is_vec: true, + is_option: false, + dims, + }), + "Vec < f32 >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::F32 }, + is_vec: true, + is_option: false, + dims, + }), + "Vec < bool >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::Bool }, + is_vec: true, + is_option: false, + dims, + }), + "Option < Vec < u64 > >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::U64 }, + is_vec: true, + is_option: true, + dims, + }), + "Option < Vec < i64 > >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::I64 }, + is_vec: true, + is_option: true, + dims, + }), + "Option < Vec < f64 > >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::F64 }, + is_vec: true, + is_option: true, + dims, + }), + "Option < Vec < f32 > >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::F32 }, + is_vec: true, + is_option: true, + dims, + }), + "Option< Vec < bool > >" => StorableField::Basic(StorableBasicField { + name: field_name.clone(), + item_type: quote! { nuts_storable::ItemType::Bool }, + is_vec: true, + is_option: true, + dims, + }), + _ => { + // Attempt to handle complex generic types that are still Storable + if let Type::Path(type_path) = ty { + // Check if it's a type that has the Storable trait + let type_token = quote!(#type_path); + storable_fields.push(StorableField::Inner(StorableInnerField { + name: field_name.clone(), + item_type: type_token, + is_option: false, + })); + continue; + } else { + panic!("Unsupported field type: {}", ty_str); + } + } + }; + storable_fields.push(item); + } + + let names_exprs = storable_fields.iter().map(|f| match f { + StorableField::Basic(field) => { + let name = field.name.to_string(); + quote! { vec![#name] } + } + StorableField::Inner(field) => { + let item_type = &field.item_type; + quote! { #item_type::names(parent) } + } + StorableField::Generic(field) => { + let name = field.name.to_string(); + if field.is_option { + quote! { vec![#name] } + } else { + let item_type = &field.item_type; + quote! { #item_type::names(parent) } + } + } + }); + + let names_fn = quote! 
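+    // Sketch of the generated body (assuming a scalar field `step_size` and a
+    // flattened field `inner: Inner`):
+    //
+    //     let mut names = Vec::new();
+    //     names.extend(vec!["step_size"]);
+    //     names.extend(Inner::names(parent));
+    //     names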
{ + fn names(parent: &P) -> Vec<&str> { + let mut names = Vec::new(); + #(names.extend(#names_exprs);)* + names + } + }; + + let item_type_arms = storable_fields.iter().map(|f| match f { + StorableField::Basic(field) => { + let name_str = field.name.to_string(); + let item_type = &field.item_type; + quote! { #name_str => #item_type, } + } + StorableField::Inner(field) => { + let item_type = &field.item_type; + quote! { name if #item_type::names(parent).contains(&name) => #item_type::item_type(parent, name), } + } + StorableField::Generic(field) => { + let name_str = field.name.to_string(); + let item_type = &field.item_type; + if field.is_option { + quote! { #name_str => nuts_storable::ItemType::Generic, } + } else { + quote! { name if #item_type::names(parent).contains(&name) => #item_type::item_type(parent, name), } + } + } + }); + + let item_type_fn = quote! { + fn item_type(parent: &P, item: &str) -> nuts_storable::ItemType { + match item { + #(#item_type_arms)* + _ => { panic!("Unknown item: {}", item); } + } + } + }; + + let dims_arms = storable_fields.iter().map(|f| match f { + StorableField::Basic(field) => { + let name_str = field.name.to_string(); + let dims = &field.dims; + quote! { #name_str => vec![#(#dims),*], } + } + StorableField::Inner(field) => { + let item_type = &field.item_type; + quote! { name if #item_type::names(parent).contains(&name) => #item_type::dims(parent, name), } + } + StorableField::Generic(field) => { + let name_str = field.name.to_string(); + let item_type = &field.item_type; + if field.is_option { + quote! { #name_str => vec![], } + } else { + quote! { name if #item_type::names(parent).contains(&name) => #item_type::dims(parent, name), } + } + } + }); + + let dims_fn = quote! { + fn dims<'a>(parent: &'a P, item: &str) -> Vec<&'a str> { + match item { + #(#dims_arms)* + _ => { panic!("Unknown item: {}", item); } + } + } + }; + + let get_all_exprs = storable_fields.iter().map(|f| match f { + StorableField::Basic(field) => { + let name = &field.name; + let name_str = name.to_string(); + let value_expr = if field.is_option { + if field.is_vec { + quote! { self.#name.as_ref().map(|v| nuts_storable::Value::from(v.clone())) } + } else { + quote! { self.#name.map(nuts_storable::Value::from) } + } + } else { + quote! { Some(nuts_storable::Value::from(self.#name.clone())) } + }; + quote! { result.push((#name_str, #value_expr)); } + } + StorableField::Inner(field) => { + let name = &field.name; + if field.is_option { + quote! { + if let Some(inner) = &self.#name { + result.extend(inner.get_all(parent)); + } + } + } else { + quote! { result.extend(self.#name.get_all(parent)); } + } + } + StorableField::Generic(field) => { + let name = &field.name; + if field.is_option { + quote! { + if let Some(inner) = &self.#name { + result.push((#name.to_string().as_str(), Some(nuts_storable::Value::Generic(Box::new(inner.clone()))))); + } else { + result.push((#name.to_string().as_str(), None)); + } + } + } else { + quote! { result.extend(self.#name.get_all(parent)); } + } + } + }); + + let get_all_fn = quote! { + fn get_all(&self, parent: &P) -> Vec<(&str, Option)> { + let mut result = Vec::with_capacity(Self::names(parent).len()); + #(#get_all_exprs)* + result + } + }; + + let r#gen = quote! { + impl #impl_generics nuts_storable::Storable
<P>
for #name #ty_generics #where_clause { + #names_fn + #item_type_fn + #dims_fn + #get_all_fn + } + }; + + r#gen.into() +} diff --git a/nuts-derive/tests/storable.rs b/nuts-derive/tests/storable.rs new file mode 100644 index 0000000..fa0bf34 --- /dev/null +++ b/nuts-derive/tests/storable.rs @@ -0,0 +1,133 @@ +use std::collections::HashMap; + +use nuts_derive::Storable; +use nuts_storable::{HasDims, Storable}; +use nuts_storable::{ItemType, Value}; + +#[derive(Storable, Clone)] +struct InnerStats { + value: f64, + #[storable(dims("dim"))] + draws: Vec, +} + +#[derive(Storable, Clone)] +struct InnerStats2 { + value2: f64, + #[storable(dims("dim"))] + draws2: Vec, +} + +#[derive(Storable, Clone)] +struct ExampleStats { + step_size: f64, + n_steps: u64, + is_adapting: bool, + #[storable(dims("dim", "dim2"))] + gradients: Vec, + #[storable(dims("dim", "dim2"))] + gradients2: Option>, + #[storable(flatten)] + inner: InnerStats, + #[storable(flatten)] + inner2: Option, + #[storable(ignore)] + _not_stored: String, +} + +#[derive(Storable)] +struct Example2> { + field1: u64, + field2: S, + #[storable(ignore)] + _phantom: std::marker::PhantomData P>, +} + +#[test] +fn test_storable() { + struct Parent {} + + impl nuts_storable::HasDims for Parent { + fn dim_sizes(&self) -> HashMap { + HashMap::from([("dim".to_string(), 3), ("dim2".to_string(), 3)]) + } + } + + let inner = InnerStats { + value: 1.0, + draws: vec![1.0, 2.0, 3.0], + }; + let inner2 = InnerStats2 { + value2: 8.0, + draws2: vec![9.0, 2.0, 3.0], + }; + let stats = ExampleStats { + step_size: 0.1, + n_steps: 10, + is_adapting: true, + gradients: vec![0.1, 0.2, 0.3], + gradients2: None, + inner, + inner2: Some(inner2), + _not_stored: "should not be stored".to_string(), + }; + + let stats2: Example2 = Example2 { + field1: 42, + field2: stats.clone(), + _phantom: std::marker::PhantomData, + }; + + let parent = Parent {}; + + assert_eq!( + ExampleStats::names(&parent), + vec![ + "step_size".to_string(), + "n_steps".to_string(), + "is_adapting".to_string(), + "gradients".to_string(), + "gradients2".to_string(), + "value".to_string(), + "draws".to_string(), + "value2".to_string(), + "draws2".to_string(), + ] + ); + + assert_eq!(ExampleStats::item_type(&parent, "step_size"), ItemType::F64); + assert_eq!(ExampleStats::item_type(&parent, "n_steps"), ItemType::U64); + assert_eq!( + ExampleStats::item_type(&parent, "is_adapting"), + ItemType::Bool + ); + assert_eq!(ExampleStats::item_type(&parent, "gradients"), ItemType::F64); + + assert_eq!(ExampleStats::dims(&parent, "step_size").len(), 0); + assert_eq!(ExampleStats::dims(&parent, "n_steps").len(), 0); + assert_eq!(ExampleStats::dims(&parent, "is_adapting").len(), 0); + assert_eq!( + ExampleStats::dims(&parent, "gradients"), + vec!["dim".to_string(), "dim2".to_string()] + ); + assert_eq!( + ExampleStats::dims(&parent, "draws"), + vec!["dim".to_string()] + ); + + let vals = stats.get_all(&parent); + assert_eq!(vals.len(), 9); + assert_eq!(vals[0].1, Some(Value::ScalarF64(0.1))); + assert_eq!(vals[1].1, Some(Value::ScalarU64(10))); + assert_eq!(vals[2].1, Some(Value::ScalarBool(true))); + assert_eq!(vals[4].1, None); + assert_eq!(vals[7].1, Some(Value::ScalarF64(8.0))); + + assert_eq!( + Example2::<_, ExampleStats>::item_type(&parent, "step_size"), + ItemType::F64 + ); + + let vals2 = stats2.field2.get_all(&parent); + assert_eq!(vals2.len(), 9); +} diff --git a/nuts-storable/Cargo.toml b/nuts-storable/Cargo.toml new file mode 100644 index 0000000..6ff3383 --- /dev/null +++ b/nuts-storable/Cargo.toml @@ 
-0,0 +1,6 @@ +[package] +name = "nuts-storable" +version = "0.1.0" +edition = "2024" + +[lib] diff --git a/nuts-storable/src/lib.rs b/nuts-storable/src/lib.rs new file mode 100644 index 0000000..401f45e --- /dev/null +++ b/nuts-storable/src/lib.rs @@ -0,0 +1,229 @@ +use std::collections::HashMap; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ItemType { + U64, + I64, + F64, + F32, + Bool, + String, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Value { + U64(Vec), + I64(Vec), + F64(Vec), + F32(Vec), + Bool(Vec), + ScalarString(String), + ScalarU64(u64), + ScalarI64(i64), + ScalarF64(f64), + ScalarF32(f32), + ScalarBool(bool), + Strings(Vec), +} + +impl From> for Value { + fn from(value: Vec) -> Self { + Value::U64(value) + } +} +impl From> for Value { + fn from(value: Vec) -> Self { + Value::I64(value) + } +} +impl From> for Value { + fn from(value: Vec) -> Self { + Value::F64(value) + } +} +impl From> for Value { + fn from(value: Vec) -> Self { + Value::F32(value) + } +} +impl From> for Value { + fn from(value: Vec) -> Self { + Value::Bool(value) + } +} +impl From for Value { + fn from(value: u64) -> Self { + Value::ScalarU64(value) + } +} +impl From for Value { + fn from(value: i64) -> Self { + Value::ScalarI64(value) + } +} +impl From for Value { + fn from(value: f64) -> Self { + Value::ScalarF64(value) + } +} +impl From for Value { + fn from(value: f32) -> Self { + Value::ScalarF32(value) + } +} +impl From for Value { + fn from(value: bool) -> Self { + Value::ScalarBool(value) + } +} + +pub trait HasDims { + fn dim_sizes(&self) -> HashMap; + fn coords(&self) -> HashMap { + HashMap::new() + } +} + +pub trait Storable: Send + Sync { + fn names(parent: &P) -> Vec<&str>; + fn item_type(parent: &P, item: &str) -> ItemType; + fn dims<'a>(parent: &'a P, item: &str) -> Vec<&'a str>; + + fn get_all(&self, parent: &P) -> Vec<(&str, Option)>; + + fn get_f64(&self, parent: &P, name: &str) -> Option { + self.get_all(parent) + .into_iter() + .find(|(item_name, _)| *item_name == name) + .and_then(|(_, value)| match value { + Some(Value::ScalarF64(v)) => Some(v), + _ => None, + }) + } + + fn get_f32(&self, parent: &P, name: &str) -> Option { + self.get_all(parent) + .into_iter() + .find(|(item_name, _)| *item_name == name) + .and_then(|(_, value)| match value { + Some(Value::ScalarF32(v)) => Some(v), + _ => None, + }) + } + + fn get_u64(&self, parent: &P, name: &str) -> Option { + self.get_all(parent) + .into_iter() + .find(|(item_name, _)| *item_name == name) + .and_then(|(_, value)| match value { + Some(Value::ScalarU64(v)) => Some(v), + _ => None, + }) + } + + fn get_i64(&self, parent: &P, name: &str) -> Option { + self.get_all(parent) + .into_iter() + .find(|(item_name, _)| *item_name == name) + .and_then(|(_, value)| match value { + Some(Value::ScalarI64(v)) => Some(v), + _ => None, + }) + } + + fn get_bool(&self, parent: &P, name: &str) -> Option { + self.get_all(parent) + .into_iter() + .find(|(item_name, _)| *item_name == name) + .and_then(|(_, value)| match value { + Some(Value::ScalarBool(v)) => Some(v), + _ => None, + }) + } + + fn get_vec_f64(&self, parent: &P, name: &str) -> Option> { + self.get_all(parent) + .into_iter() + .find(|(item_name, _)| *item_name == name) + .and_then(|(_, value)| match value { + Some(Value::F64(v)) => Some(v), + _ => None, + }) + } + + fn get_vec_f32(&self, parent: &P, name: &str) -> Option> { + self.get_all(parent) + .into_iter() + .find(|(item_name, _)| *item_name == name) + .and_then(|(_, value)| match value { + Some(Value::F32(v)) => Some(v), 
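+                // Any other variant, or a missing value, yields `None`.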
+                _ => None,
+            })
+    }
+
+    fn get_vec_u64(&self, parent: &P, name: &str) -> Option<Vec<u64>> {
+        self.get_all(parent)
+            .into_iter()
+            .find(|(item_name, _)| *item_name == name)
+            .and_then(|(_, value)| match value {
+                Some(Value::U64(v)) => Some(v),
+                _ => None,
+            })
+    }
+
+    fn get_vec_i64(&self, parent: &P, name: &str) -> Option<Vec<i64>> {
+        self.get_all(parent)
+            .into_iter()
+            .find(|(item_name, _)| *item_name == name)
+            .and_then(|(_, value)| match value {
+                Some(Value::I64(v)) => Some(v),
+                _ => None,
+            })
+    }
+
+    fn get_vec_bool(&self, parent: &P, name: &str) -> Option<Vec<bool>> {
+        self.get_all(parent)
+            .into_iter()
+            .find(|(item_name, _)| *item_name == name)
+            .and_then(|(_, value)| match value {
+                Some(Value::Bool(v)) => Some(v),
+                _ => None,
+            })
+    }
+}
+
+impl<P: HasDims> Storable<P> for Vec<f64> {
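+    // A bare `Vec<f64>` is exposed as a single f64 item named "value" with one
+    // dimension named "dim".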
+    fn names(_parent: &P) -> Vec<&str> {
+        vec!["value"]
+    }
+
+    fn item_type(_parent: &P, _item: &str) -> ItemType {
+        ItemType::F64
+    }
+
+    fn dims<'a>(_parent: &'a P, _item: &str) -> Vec<&'a str> {
+        vec!["dim"]
+    }
+
+    fn get_all(&self, _parent: &P) -> Vec<(&str, Option<Value>)> {
+        vec![("value", Some(Value::F64(self.clone())))]
+    }
+}
+
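+// The unit type is a no-op `Storable`: components that record no statistics
+// can use `()` as their stats type.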
+impl<P: HasDims> Storable<P>
for () { + fn names(_parent: &P) -> Vec<&str> { + vec![] + } + + fn item_type(_parent: &P, _item: &str) -> ItemType { + panic!("No items in unit type") + } + + fn dims<'a>(_parent: &'a P, _item: &str) -> Vec<&'a str> { + panic!("No items in unit type") + } + + fn get_all(&self, _parent: &P) -> Vec<(&str, Option)> { + vec![] + } +} diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index 5b14935..79939e4 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -1,24 +1,22 @@ -use std::{fmt::Debug, marker::PhantomData, ops::Deref}; +use std::{fmt::Debug, marker::PhantomData}; -use arrow::array::StructArray; -use itertools::Itertools; +use nuts_derive::Storable; +use nuts_storable::{HasDims, Storable}; use rand::Rng; +use serde::Serialize; +use super::mass_matrix::MassMatrixAdaptStrategy; +use super::stepsize::AcceptanceRateCollector; +use super::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy}; use crate::{ + NutsError, chain::AdaptStrategy, euclidean_hamiltonian::EuclideanHamiltonian, hamiltonian::{DivergenceInfo, Hamiltonian, Point}, - mass_matrix_adapt::MassMatrixAdaptStrategy, math_base::Math, nuts::{Collector, NutsOptions}, - sampler::Settings, - sampler_stats::{SamplerStats, StatTraceBuilder}, + sampler_stats::{SamplerStats, StatsDims}, state::State, - stepsize::AcceptanceRateCollector, - stepsize_adapt::{ - DualAverageSettings, StatsBuilder as StepSizeStatsBuilder, Strategy as StepSizeStrategy, - }, - NutsError, }; pub struct GlobalStrategy> { @@ -36,9 +34,9 @@ pub struct GlobalStrategy> { last_update: u64, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Serialize)] pub struct EuclideanAdaptOptions { - pub dual_average_options: DualAverageSettings, + pub step_size_settings: StepSizeSettings, pub mass_matrix_options: S, pub early_window: f64, pub step_size_window: f64, @@ -50,7 +48,7 @@ pub struct EuclideanAdaptOptions { impl Default for EuclideanAdaptOptions { fn default() -> Self { Self { - dual_average_options: DualAverageSettings::default(), + step_size_settings: StepSizeSettings::default(), mass_matrix_options: S::default(), early_window: 0.3, step_size_window: 0.15, @@ -61,23 +59,6 @@ impl Default for EuclideanAdaptOptions { } } -impl> SamplerStats for GlobalStrategy { - type Builder = GlobalStrategyBuilder; - type StatOptions = >::StatOptions; - - fn new_builder( - &self, - options: Self::StatOptions, - settings: &impl Settings, - dim: usize, - ) -> Self::Builder { - GlobalStrategyBuilder { - step_size: SamplerStats::::new_builder(&self.step_size, (), settings, dim), - mass_matrix: self.mass_matrix.new_builder(options, settings, dim), - } - } -} - impl> AdaptStrategy for GlobalStrategy { type Hamiltonian = EuclideanHamiltonian; type Collector = CombinedCollector< @@ -97,7 +78,7 @@ impl> AdaptStrategy for GlobalStrategy assert!(early_end < num_tune); Self { - step_size: StepSizeStrategy::new(options.dual_average_options), + step_size: StepSizeStrategy::new(options.step_size_settings), mass_matrix: A::new(math, options.mass_matrix_options, num_tune, chain), options, num_tune, @@ -143,6 +124,8 @@ impl> AdaptStrategy for GlobalStrategy self.step_size.update(&collector.collector1); if draw >= self.num_tune { + // Needed for step size jitter + self.step_size.update_stepsize(rng, hamiltonian, true); self.tuning = false; return Ok(()); } @@ -194,14 +177,14 @@ impl> AdaptStrategy for GlobalStrategy self.step_size .init(math, options, hamiltonian, &position, rng)?; } else { - self.step_size.update_stepsize(hamiltonian, false) + 
self.step_size.update_stepsize(rng, hamiltonian, false) } return Ok(()); } self.step_size.update_estimator_late(); let is_last = draw == self.num_tune - 1; - self.step_size.update_stepsize(hamiltonian, is_last); + self.step_size.update_stepsize(rng, hamiltonian, is_last); Ok(()) } @@ -221,77 +204,46 @@ impl> AdaptStrategy for GlobalStrategy } } -pub struct GlobalStrategyBuilder { - pub step_size: StepSizeStatsBuilder, - pub mass_matrix: B, +#[derive(Debug, Storable)] +pub struct GlobalStrategyStats, M: Storable
<P>
> { + #[storable(flatten)] + pub step_size: S, + #[storable(flatten)] + pub mass_matrix: M, + #[storable(ignore)] + _phantom: std::marker::PhantomData P>, } -impl StatTraceBuilder> for GlobalStrategyBuilder -where - A: MassMatrixAdaptStrategy, -{ - fn append_value(&mut self, math: Option<&mut M>, value: &GlobalStrategy) { - let math = math.expect("Smapler stats need math"); - self.step_size.append_value(Some(math), &value.step_size); - self.mass_matrix - .append_value(Some(math), &value.mass_matrix); - } +#[derive(Debug)] +pub struct GlobalStrategyStatsOptions> { + pub step_size: (), + pub mass_matrix: A::StatsOptions, +} - fn finalize(self) -> Option { - let Self { - step_size, - mass_matrix, - } = self; - match ( - StatTraceBuilder::::finalize(step_size), - mass_matrix.finalize(), - ) { - (None, None) => None, - (Some(stats1), None) => Some(stats1), - (None, Some(stats2)) => Some(stats2), - (Some(stats1), Some(stats2)) => { - let mut data1 = stats1.into_parts(); - let data2 = stats2.into_parts(); - - assert!(data1.2.is_none()); - assert!(data2.2.is_none()); - - let mut fields = data1.0.into_iter().map(|x| x.deref().clone()).collect_vec(); - - fields.extend(data2.0.into_iter().map(|x| x.deref().clone())); - data1.1.extend(data2.1); - - Some(StructArray::new(data1.0, data1.1, None)) - } - } +impl> Clone for GlobalStrategyStatsOptions { + fn clone(&self) -> Self { + *self } +} - fn inspect(&self) -> Option { - let Self { - step_size, - mass_matrix, - } = self; - match ( - StatTraceBuilder::::inspect(step_size), - mass_matrix.inspect(), - ) { - (None, None) => None, - (Some(stats1), None) => Some(stats1), - (None, Some(stats2)) => Some(stats2), - (Some(stats1), Some(stats2)) => { - let mut data1 = stats1.into_parts(); - let data2 = stats2.into_parts(); - - assert!(data1.2.is_none()); - assert!(data2.2.is_none()); - - let mut fields = data1.0.into_iter().map(|x| x.deref().clone()).collect_vec(); - - fields.extend(data2.0.into_iter().map(|x| x.deref().clone())); - data1.1.extend(data2.1); - - Some(StructArray::new(data1.0, data1.1, None)) - } +impl> Copy for GlobalStrategyStatsOptions {} + +impl SamplerStats for GlobalStrategy +where + A: MassMatrixAdaptStrategy, +{ + type Stats = + GlobalStrategyStats>::Stats, A::Stats>; + type StatsOptions = GlobalStrategyStatsOptions; + + fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats { + GlobalStrategyStats { + step_size: { + let _: () = opt.step_size; + self.step_size.extract_stats(math, ()) + }, + mass_matrix: self.mass_matrix.extract_stats(math, opt.mass_matrix), + _phantom: PhantomData, } } } @@ -364,7 +316,10 @@ where #[cfg(test)] pub mod test_logps { + use std::collections::HashMap; + use crate::{cpu_math::CpuLogpFunc, math_base::LogpError}; + use nuts_storable::HasDims; use thiserror::Error; #[derive(Clone, Debug)] @@ -388,9 +343,18 @@ pub mod test_logps { } } + impl HasDims for NormalLogp { + fn dim_sizes(&self) -> HashMap { + vec![("unconstrained_parameter".to_string(), self.dim as u64)] + .into_iter() + .collect() + } + } + impl CpuLogpFunc for NormalLogp { type LogpError = NormalLogpError; - type TransformParams = (); + type FlowParameters = (); + type ExpandedVector = Vec; fn dim(&self) -> usize { self.dim @@ -408,9 +372,20 @@ pub mod test_logps { Ok(logp) } + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> Result + where + R: rand::Rng + ?Sized, + { + Ok(array.to_vec()) + } + fn inv_transform_normalize( &mut self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, 
_untransformed_position: &[f64], _untransofrmed_gradient: &[f64], _transformed_position: &mut [f64], @@ -421,7 +396,7 @@ pub mod test_logps { fn init_from_transformed_position( &mut self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, _untransformed_position: &mut [f64], _untransformed_gradient: &mut [f64], _transformed_position: &[f64], @@ -432,7 +407,7 @@ pub mod test_logps { fn init_from_untransformed_position( &mut self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, _untransformed_position: &[f64], _untransformed_gradient: &mut [f64], _transformed_position: &mut [f64], @@ -447,7 +422,7 @@ pub mod test_logps { _untransformed_positions: impl Iterator, _untransformed_gradients: impl Iterator, _untransformed_logp: impl Iterator, - _params: &'a mut Self::TransformParams, + _params: &'a mut Self::FlowParameters, ) -> Result<(), Self::LogpError> { unimplemented!() } @@ -458,13 +433,13 @@ pub mod test_logps { _untransformed_position: &[f64], _untransfogmed_gradient: &[f64], _chain: u64, - ) -> Result { + ) -> Result { unimplemented!() } fn transformation_id( &self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, ) -> Result { unimplemented!() } @@ -476,13 +451,16 @@ mod test { use super::test_logps::NormalLogp; use super::*; use crate::{ - chain::NutsChain, cpu_math::CpuMath, euclidean_hamiltonian::EuclideanHamiltonian, - mass_matrix::DiagMassMatrix, Chain, DiagAdaptExpSettings, + Chain, DiagAdaptExpSettings, + chain::{NutsChain, StatOptions}, + cpu_math::CpuMath, + euclidean_hamiltonian::EuclideanHamiltonian, + mass_matrix::DiagMassMatrix, }; #[test] fn instanciate_adaptive_sampler() { - use crate::mass_matrix_adapt::Strategy; + use crate::mass_matrix::Strategy; let ndim = 10; let func = NormalLogp::new(ndim, 3.); @@ -499,6 +477,7 @@ mod test { EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, step_size); let options = NutsOptions { maxdepth: 10u64, + mindepth: 0, store_gradient: true, store_unconstrained: true, check_turning: true, @@ -511,7 +490,24 @@ mod test { }; let chain = 0u64; - let mut sampler = NutsChain::new(math, hamiltonian, strategy, options, rng, chain); + let stats_options = StatOptions { + adapt: GlobalStrategyStatsOptions { + step_size: (), + mass_matrix: (), + }, + hamiltonian: (), + point: (), + }; + + let mut sampler = NutsChain::new( + math, + hamiltonian, + strategy, + options, + rng, + chain, + stats_options, + ); sampler.set_position(&vec![1.5f64; ndim]).unwrap(); for _ in 0..200 { sampler.draw().unwrap(); diff --git a/src/chain.rs b/src/chain.rs index 441a32e..026312a 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,26 +1,20 @@ use std::{ - cell::RefCell, + cell::{Ref, RefCell}, fmt::Debug, - ops::{Deref, DerefMut}, - sync::Arc, + marker::PhantomData, + ops::DerefMut, }; -use arrow::{ - array::{ - Array, ArrayBuilder, BooleanBuilder, FixedSizeListBuilder, PrimitiveBuilder, StringBuilder, - StructArray, - }, - datatypes::{DataType, Field, Fields, Float64Type, Int64Type, UInt64Type}, -}; +use nuts_storable::{HasDims, Storable}; use rand::Rng; use crate::{ + Math, NutsError, hamiltonian::{Hamiltonian, Point}, - nuts::{draw, Collector, NutsOptions, SampleInfo}, + nuts::{Collector, NutsOptions, SampleInfo, draw}, sampler::Progress, - sampler_stats::{SamplerStats, StatTraceBuilder}, + sampler_stats::{SamplerStats, StatsDims}, state::State, - Math, NutsError, Settings, }; use anyhow::Result; @@ -40,6 +34,10 @@ pub trait Chain: SamplerStats { /// The dimensionality of the posterior. 
fn dim(&self) -> usize; + + fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)>; + + fn math(&self) -> Ref<'_, M>; } pub struct NutsChain @@ -58,6 +56,7 @@ where draw_count: u64, strategy: A, math: RefCell, + stats_options: StatOptions, } impl NutsChain @@ -73,6 +72,7 @@ where options: NutsOptions, rng: R, chain: u64, + stats_options: StatOptions, ) -> Self { let init = hamiltonian.pool().new_state(&mut math); let collector = strategy.new_collector(&mut math); @@ -87,6 +87,7 @@ where draw_count: 0, strategy, math: math.into(), + stats_options, } } } @@ -124,33 +125,6 @@ pub trait AdaptStrategy: SamplerStats { fn last_num_steps(&self) -> u64; } -impl SamplerStats for NutsChain -where - M: Math, - R: rand::Rng, - A: AdaptStrategy, -{ - type Builder = NutsStatsBuilder; - type StatOptions = StatOptions; - - fn new_builder( - &self, - options: StatOptions, - settings: &impl Settings, - dim: usize, - ) -> Self::Builder { - NutsStatsBuilder::new_with_capacity( - options, - settings, - &self.hamiltonian, - &self.strategy, - self.state.point(), - dim, - &self.options, - ) - } -} - impl Chain for NutsChain where M: Math, @@ -187,15 +161,6 @@ where let mut position: Box<[f64]> = vec![0f64; math.dim()].into(); state.write_position(math, &mut position); - let progress = Progress { - draw: self.draw_count, - chain: self.chain, - diverging: info.divergence_info.is_some(), - tuning: self.strategy.is_tuning(), - step_size: self.hamiltonian.step_size(), - num_steps: self.strategy.last_num_steps(), - }; - self.strategy.adapt( math, &mut self.options, @@ -205,6 +170,14 @@ where &state, &mut self.rng, )?; + let progress = Progress { + draw: self.draw_count, + chain: self.chain, + diverging: info.divergence_info.is_some(), + tuning: self.strategy.is_tuning(), + step_size: self.hamiltonian.step_size(), + num_steps: self.strategy.last_num_steps(), + }; self.draw_count += 1; @@ -213,453 +186,125 @@ where Ok((position, progress)) } + fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)> { + let (position, progress) = self.draw()?; + let mut math_ = self.math.borrow_mut(); + let math = math_.deref_mut(); + + let stats = self.extract_stats(&mut *math, self.stats_options); + let expanded = math.expand_vector(&mut self.rng, self.state.point().position())?; + + Ok((position, expanded, stats, progress)) + } + fn dim(&self) -> usize { self.math.borrow().dim() } + + fn math(&self) -> Ref<'_, M> { + self.math.borrow() + } } -pub struct NutsStatsBuilder> { - depth: PrimitiveBuilder, - maxdepth_reached: BooleanBuilder, - index_in_trajectory: PrimitiveBuilder, - logp: PrimitiveBuilder, - energy: PrimitiveBuilder, - chain: PrimitiveBuilder, - draw: PrimitiveBuilder, - energy_error: PrimitiveBuilder, - unconstrained: Option>>, - gradient: Option>>, - hamiltonian: >::Builder, - adapt: A::Builder, - point: <>::Point as SamplerStats>::Builder, - diverging: BooleanBuilder, - divergence_start: Option>>, - divergence_start_grad: Option>>, - divergence_end: Option>>, - divergence_momentum: Option>>, - divergence_msg: Option, +#[derive(Debug, nuts_derive::Storable)] +pub struct NutsStats, A: Storable
<P>
, D: Storable
<P>
> { + pub depth: u64, + pub maxdepth_reached: bool, + pub index_in_trajectory: i64, + pub logp: f64, + pub energy: f64, + pub chain: u64, + pub draw: u64, + pub energy_error: f64, + #[storable(dims("unconstrained_parameter"))] + pub unconstrained: Option>, + #[storable(dims("unconstrained_parameter"))] + pub gradient: Option>, + #[storable(flatten)] + pub hamiltonian: H, + #[storable(flatten)] + pub adapt: A, + #[storable(flatten)] + pub point: D, + pub diverging: bool, + #[storable(dims("unconstrained_parameter"))] + pub divergence_start: Option>, + #[storable(dims("unconstrained_parameter"))] + pub divergence_start_gradient: Option>, + #[storable(dims("unconstrained_parameter"))] + pub divergence_end: Option>, + #[storable(dims("unconstrained_parameter"))] + pub divergence_momentum: Option>, + //pub divergence_message: Option, + #[storable(ignore)] + _phantom: PhantomData P>, } pub struct StatOptions> { - pub adapt: A::StatOptions, - pub hamiltonian: >::StatOptions, - pub point: <>::Point as SamplerStats>::StatOptions, + pub adapt: A::StatsOptions, + pub hamiltonian: >::StatsOptions, + pub point: <>::Point as SamplerStats>::StatsOptions, } -impl> NutsStatsBuilder { - pub fn new_with_capacity( - stat_options: StatOptions, - settings: &impl Settings, - hamiltonian: &A::Hamiltonian, - adapt: &A, - point: &>::Point, - dim: usize, - options: &NutsOptions, - ) -> Self { - let capacity = settings.hint_num_tune() + settings.hint_num_draws(); - - let gradient = if options.store_gradient { - let items = PrimitiveBuilder::with_capacity(capacity); - Some(FixedSizeListBuilder::new(items, dim as i32)) - } else { - None - }; - - let unconstrained = if options.store_unconstrained { - let items = PrimitiveBuilder::with_capacity(capacity); - Some(FixedSizeListBuilder::with_capacity( - items, dim as i32, capacity, - )) - } else { - None - }; - - let (div_start, div_start_grad, div_end, div_mom, div_msg) = if options.store_divergences { - let start_location_prim = PrimitiveBuilder::new(); - let start_location_list = FixedSizeListBuilder::new(start_location_prim, dim as i32); - - let start_grad_prim = PrimitiveBuilder::new(); - let start_grad_list = FixedSizeListBuilder::new(start_grad_prim, dim as i32); - - let end_location_prim = PrimitiveBuilder::new(); - let end_location_list = FixedSizeListBuilder::new(end_location_prim, dim as i32); - - let momentum_location_prim = PrimitiveBuilder::new(); - let momentum_location_list = - FixedSizeListBuilder::new(momentum_location_prim, dim as i32); - - let msg_list = StringBuilder::new(); - - ( - Some(start_location_list), - Some(start_grad_list), - Some(end_location_list), - Some(momentum_location_list), - Some(msg_list), - ) - } else { - (None, None, None, None, None) - }; - - Self { - depth: PrimitiveBuilder::with_capacity(capacity), - maxdepth_reached: BooleanBuilder::with_capacity(capacity), - index_in_trajectory: PrimitiveBuilder::with_capacity(capacity), - logp: PrimitiveBuilder::with_capacity(capacity), - energy: PrimitiveBuilder::with_capacity(capacity), - chain: PrimitiveBuilder::with_capacity(capacity), - draw: PrimitiveBuilder::with_capacity(capacity), - energy_error: PrimitiveBuilder::with_capacity(capacity), - gradient, - unconstrained, - hamiltonian: hamiltonian.new_builder(stat_options.hamiltonian, settings, dim), - adapt: adapt.new_builder(stat_options.adapt, settings, dim), - point: point.new_builder(stat_options.point, settings, dim), - diverging: BooleanBuilder::with_capacity(capacity), - divergence_start: div_start, - divergence_start_grad: 
div_start_grad, - divergence_end: div_end, - divergence_momentum: div_mom, - divergence_msg: div_msg, - } +impl Clone for StatOptions +where + M: Math, + A: AdaptStrategy, +{ + fn clone(&self) -> Self { + *self } } -impl> StatTraceBuilder> - for NutsStatsBuilder +impl Copy for StatOptions +where + M: Math, + A: AdaptStrategy, { - fn append_value(&mut self, _math: Option<&mut M>, value: &NutsChain) { - let mut math_ = value.math.borrow_mut(); - let math = math_.deref_mut(); - let Self { - ref mut depth, - ref mut maxdepth_reached, - ref mut index_in_trajectory, - logp, - energy, - chain, - draw, - energy_error, - ref mut unconstrained, - ref mut gradient, - hamiltonian, - adapt, - point, - diverging, - ref mut divergence_start, - divergence_start_grad, - divergence_end, - divergence_momentum, - divergence_msg, - } = self; - - let info = value.last_info.as_ref().expect("Sampler has not started"); - let draw_point = value.state.point(); - - depth.append_value(info.depth); - maxdepth_reached.append_value(info.reached_maxdepth); - index_in_trajectory.append_value(draw_point.index_in_trajectory()); - logp.append_value(draw_point.logp()); - energy.append_value(draw_point.energy()); - chain.append_value(value.chain); - draw.append_value(value.draw_count); - diverging.append_value(info.divergence_info.is_some()); - energy_error.append_value(draw_point.energy_error()); - - fn add_slice>( - store: &mut Option>>, - values: Option, - n_dim: usize, - ) { - let Some(store) = store.as_mut() else { - return; - }; - - if let Some(values) = values.as_ref() { - store.values().append_slice(values.as_ref()); - store.append(true); - } else { - store.values().append_nulls(n_dim); - store.append(false); - } - } - - let n_dim = math.dim(); - add_slice(gradient, Some(math.box_array(draw_point.gradient())), n_dim); - add_slice( - unconstrained, - Some(math.box_array(draw_point.position())), - n_dim, - ); +} +impl> SamplerStats for NutsChain { + type Stats = NutsStats< + StatsDims, + >::Stats, + A::Stats, + <>::Point as SamplerStats>::Stats, + >; + type StatsOptions = StatOptions; + + fn extract_stats(&self, math: &mut M, options: Self::StatsOptions) -> Self::Stats { + let hamiltonian_stats = self.hamiltonian.extract_stats(math, options.hamiltonian); + let adapt_stats = self.strategy.extract_stats(math, options.adapt); + let point_stats = self.state.point().extract_stats(math, options.point); + let info = self.last_info.as_ref().expect("Sampler has not started"); + let point = self.state.point(); let div_info = info.divergence_info.as_ref(); - add_slice( - divergence_start, - div_info.and_then(|info| info.start_location.as_ref()), - n_dim, - ); - add_slice( - divergence_start_grad, - div_info.and_then(|info| info.start_gradient.as_ref()), - n_dim, - ); - add_slice( - divergence_end, - div_info.and_then(|info| info.end_location.as_ref()), - n_dim, - ); - add_slice( - divergence_momentum, - div_info.and_then(|info| info.start_momentum.as_ref()), - n_dim, - ); - - if let Some(div_msg) = divergence_msg.as_mut() { - if let Some(err) = div_info.and_then(|info| info.logp_function_error.as_ref()) { - div_msg.append_value(format!("{err}")); - } else { - div_msg.append_null(); - } - } - - hamiltonian.append_value(Some(math), &value.hamiltonian); - adapt.append_value(Some(math), &value.strategy); - point.append_value(Some(math), draw_point); - } - - fn finalize(self) -> Option { - let Self { - mut depth, - mut maxdepth_reached, - mut index_in_trajectory, - mut logp, - mut energy, - mut chain, - mut draw, - mut energy_error, - 
unconstrained, - gradient, - hamiltonian, - adapt, - point, - mut diverging, - divergence_start, - divergence_start_grad, - divergence_end, - divergence_momentum, - divergence_msg, - } = self; - - let mut fields = vec![ - Field::new("depth", DataType::UInt64, false), - Field::new("maxdepth_reached", DataType::Boolean, false), - Field::new("index_in_trajectory", DataType::Int64, false), - Field::new("logp", DataType::Float64, false), - Field::new("energy", DataType::Float64, false), - Field::new("chain", DataType::UInt64, false), - Field::new("draw", DataType::UInt64, false), - Field::new("diverging", DataType::Boolean, false), - Field::new("energy_error", DataType::Float64, false), - ]; - - let mut arrays: Vec> = vec![ - ArrayBuilder::finish(&mut depth), - ArrayBuilder::finish(&mut maxdepth_reached), - ArrayBuilder::finish(&mut index_in_trajectory), - ArrayBuilder::finish(&mut logp), - ArrayBuilder::finish(&mut energy), - ArrayBuilder::finish(&mut chain), - ArrayBuilder::finish(&mut draw), - ArrayBuilder::finish(&mut diverging), - ArrayBuilder::finish(&mut energy_error), - ]; - - fn merge_into>( - builder: B, - arrays: &mut Vec>, - fields: &mut Vec, - ) { - let Some(struct_array) = builder.finalize() else { - return; - }; - - let (struct_fields, struct_arrays, bitmap) = struct_array.into_parts(); - assert!(bitmap.is_none()); - arrays.extend(struct_arrays); - fields.extend(struct_fields.into_iter().map(|x| x.deref().clone())); - } - - fn add_field( - mut builder: Option, - name: &str, - arrays: &mut Vec>, - fields: &mut Vec, - ) { - let Some(mut builder) = builder.take() else { - return; - }; - - let array = ArrayBuilder::finish(&mut builder); - fields.push(Field::new(name, array.data_type().clone(), true)); - arrays.push(array); - } - - merge_into(hamiltonian, &mut arrays, &mut fields); - merge_into(adapt, &mut arrays, &mut fields); - merge_into(point, &mut arrays, &mut fields); - - add_field(gradient, "gradient", &mut arrays, &mut fields); - add_field( - unconstrained, - "unconstrained_draw", - &mut arrays, - &mut fields, - ); - add_field( - divergence_start, - "divergence_start", - &mut arrays, - &mut fields, - ); - add_field( - divergence_start_grad, - "divergence_start_gradient", - &mut arrays, - &mut fields, - ); - add_field(divergence_end, "divergence_end", &mut arrays, &mut fields); - add_field( - divergence_momentum, - "divergence_momentum", - &mut arrays, - &mut fields, - ); - add_field( - divergence_msg, - "divergence_messagem", - &mut arrays, - &mut fields, - ); - - let fields = Fields::from(fields); - Some(StructArray::new(fields, arrays, None)) - } - - fn inspect(&self) -> Option { - let Self { - depth, - maxdepth_reached, - index_in_trajectory, - logp, - energy, - chain, - draw, - energy_error, - unconstrained, - gradient, - hamiltonian, - adapt, - point, - diverging, - divergence_start, - divergence_start_grad, - divergence_end, - divergence_momentum, - divergence_msg, - } = self; - - let mut fields = vec![ - Field::new("depth", DataType::UInt64, false), - Field::new("maxdepth_reached", DataType::Boolean, false), - Field::new("index_in_trajectory", DataType::Int64, false), - Field::new("logp", DataType::Float64, false), - Field::new("energy", DataType::Float64, false), - Field::new("chain", DataType::UInt64, false), - Field::new("draw", DataType::UInt64, false), - Field::new("diverging", DataType::Boolean, false), - Field::new("energy_error", DataType::Float64, false), - ]; - - let mut arrays: Vec> = vec![ - ArrayBuilder::finish_cloned(depth), - 
ArrayBuilder::finish_cloned(maxdepth_reached), - ArrayBuilder::finish_cloned(index_in_trajectory), - ArrayBuilder::finish_cloned(logp), - ArrayBuilder::finish_cloned(energy), - ArrayBuilder::finish_cloned(chain), - ArrayBuilder::finish_cloned(draw), - ArrayBuilder::finish_cloned(diverging), - ArrayBuilder::finish_cloned(energy_error), - ]; - - fn merge_into>( - builder: &B, - arrays: &mut Vec>, - fields: &mut Vec, - ) { - let Some(struct_array) = builder.inspect() else { - return; - }; - - let (struct_fields, struct_arrays, bitmap) = struct_array.into_parts(); - assert!(bitmap.is_none()); - arrays.extend(struct_arrays); - fields.extend(struct_fields.into_iter().map(|x| x.deref().clone())); - } - fn add_field( - builder: &Option, - name: &str, - arrays: &mut Vec>, - fields: &mut Vec, - ) { - let Some(builder) = builder.as_ref() else { - return; - }; - - let array = ArrayBuilder::finish_cloned(builder); - fields.push(Field::new(name, array.data_type().clone(), true)); - arrays.push(array); + NutsStats { + depth: info.depth, + maxdepth_reached: info.reached_maxdepth, + index_in_trajectory: point.index_in_trajectory(), + logp: point.logp(), + energy: point.energy(), + chain: self.chain, + draw: self.draw_count, + energy_error: point.energy_error(), + unconstrained: Some(math.box_array(point.position()).into_vec()), + gradient: Some(math.box_array(point.gradient()).into_vec()), + hamiltonian: hamiltonian_stats, + adapt: adapt_stats, + point: point_stats, + diverging: div_info.is_some(), + divergence_start: div_info + .and_then(|d| d.start_location.as_ref().map(|v| v.as_ref().to_vec())), + divergence_start_gradient: div_info + .and_then(|d| d.start_gradient.as_ref().map(|v| v.as_ref().to_vec())), + divergence_end: div_info + .and_then(|d| d.end_location.as_ref().map(|v| v.as_ref().to_vec())), + divergence_momentum: div_info + .and_then(|d| d.start_momentum.as_ref().map(|v| v.as_ref().to_vec())), + //divergence_message: self.divergence_msg.clone(), + _phantom: PhantomData, } - - merge_into(hamiltonian, &mut arrays, &mut fields); - merge_into(adapt, &mut arrays, &mut fields); - merge_into(point, &mut arrays, &mut fields); - - add_field(gradient, "gradient", &mut arrays, &mut fields); - add_field( - unconstrained, - "unconstrained_draw", - &mut arrays, - &mut fields, - ); - add_field( - divergence_start, - "divergence_start", - &mut arrays, - &mut fields, - ); - add_field( - divergence_start_grad, - "divergence_start_gradient", - &mut arrays, - &mut fields, - ); - add_field(divergence_end, "divergence_end", &mut arrays, &mut fields); - add_field( - divergence_momentum, - "divergence_momentum", - &mut arrays, - &mut fields, - ); - add_field( - divergence_msg, - "divergence_messagem", - &mut arrays, - &mut fields, - ); - - let fields = Fields::from(fields); - Some(StructArray::new(fields, arrays, None)) } } diff --git a/src/cpu_math.rs b/src/cpu_math.rs index 92e2ae5..e7d1302 100644 --- a/src/cpu_math.rs +++ b/src/cpu_math.rs @@ -1,7 +1,8 @@ -use std::{error::Error, fmt::Debug, mem::replace}; +use std::{collections::HashMap, error::Error, fmt::Debug, mem::replace}; use faer::{Col, Mat}; -use itertools::{izip, Itertools}; +use itertools::{Itertools, izip}; +use nuts_storable::{HasDims, Storable, Value}; use thiserror::Error; use crate::{ @@ -27,6 +28,38 @@ impl CpuMath { pub enum CpuMathError { #[error("Error during array operation")] ArrayError(), + #[error("Error during point expansion")] + ExpandError(), +} + +impl HasDims for CpuMath { + fn dim_sizes(&self) -> HashMap { + 
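+        // `CpuMath` has no dimensions of its own; delegate to the wrapped logp
+        // function, which knows the model's dimension names and sizes.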
self.logp_func.dim_sizes() + } + + fn coords(&self) -> HashMap { + self.logp_func.coords() + } +} + +pub struct ExpandedVectorWrapper(F::ExpandedVector); + +impl Storable> for ExpandedVectorWrapper { + fn names(parent: &CpuMath) -> Vec<&str> { + F::ExpandedVector::names(&parent.logp_func) + } + + fn item_type(parent: &CpuMath, item: &str) -> nuts_storable::ItemType { + F::ExpandedVector::item_type(&parent.logp_func, item) + } + + fn dims<'a>(parent: &'a CpuMath, item: &str) -> Vec<&'a str> { + F::ExpandedVector::dims(&parent.logp_func, item) + } + + fn get_all(&self, parent: &CpuMath) -> Vec<(&str, Option)> { + self.0.get_all(&parent.logp_func) + } } impl Math for CpuMath { @@ -35,7 +68,8 @@ impl Math for CpuMath { type EigValues = Col; type LogpErr = F::LogpError; type Err = CpuMathError; - type TransformParams = F::TransformParams; + type FlowParameters = F::FlowParameters; + type ExpandedVector = ExpandedVectorWrapper; fn new_array(&mut self) -> Self::Vector { Col::zeros(self.dim()) @@ -94,6 +128,54 @@ impl Math for CpuMath { self.logp_func.dim() } + fn expand_vector( + &mut self, + rng: &mut R, + array: &Self::Vector, + ) -> Result { + Ok(ExpandedVectorWrapper( + self.logp_func.expand_vector( + rng, + array + .try_as_col_major() + .ok_or(CpuMathError::ExpandError())? + .as_slice(), + )?, + )) + } + + fn vector_coord(&self) -> Option { + self.logp_func.vector_coord() + } + + fn init_position( + &mut self, + rng: &mut R, + position: &mut Self::Vector, + gradient: &mut Self::Vector, + ) -> Result { + let pos = position + .try_as_col_major_mut() + .expect("Array is not contiguous") + .as_slice_mut(); + + pos.iter_mut().for_each(|x| { + let val: f64 = rng.random(); + *x = val * 2f64 - 1f64 + }); + + self.logp_func.logp( + position + .try_as_col_major() + .expect("Array is not contiguous") + .as_slice(), + gradient + .try_as_col_major_mut() + .expect("Array is not contiguous") + .as_slice_mut(), + ) + } + fn scalar_prods3( &mut self, positive1: &Self::Vector, @@ -424,7 +506,7 @@ impl Math for CpuMath { fn inv_transform_normalize( &mut self, - params: &Self::TransformParams, + params: &Self::FlowParameters, untransformed_position: &Self::Vector, untransofrmed_gradient: &Self::Vector, transformed_position: &mut Self::Vector, @@ -453,7 +535,7 @@ impl Math for CpuMath { fn init_from_untransformed_position( &mut self, - params: &Self::TransformParams, + params: &Self::FlowParameters, untransformed_position: &Self::Vector, untransformed_gradient: &mut Self::Vector, transformed_position: &mut Self::Vector, @@ -482,7 +564,7 @@ impl Math for CpuMath { fn init_from_transformed_position( &mut self, - params: &Self::TransformParams, + params: &Self::FlowParameters, untransformed_position: &mut Self::Vector, untransformed_gradient: &mut Self::Vector, transformed_position: &Self::Vector, @@ -512,7 +594,7 @@ impl Math for CpuMath { untransformed_positions: impl ExactSizeIterator, untransformed_gradients: impl ExactSizeIterator, untransformed_logp: impl ExactSizeIterator, - params: &'a mut Self::TransformParams, + params: &'a mut Self::FlowParameters, ) -> Result<(), Self::LogpErr> { self.logp_func.update_transformation( rng, @@ -529,7 +611,7 @@ impl Math for CpuMath { untransformed_position: &Self::Vector, untransfogmed_gradient: &Self::Vector, chain: u64, - ) -> Result { + ) -> Result { self.logp_func.new_transformation( rng, untransformed_position @@ -544,21 +626,33 @@ impl Math for CpuMath { ) } - fn transformation_id(&self, params: &Self::TransformParams) -> Result { + fn transformation_id(&self, 
params: &Self::FlowParameters) -> Result { self.logp_func.transformation_id(params) } } -pub trait CpuLogpFunc { +pub trait CpuLogpFunc: HasDims { type LogpError: Debug + Send + Sync + Error + LogpError + 'static; - type TransformParams; + type FlowParameters; + type ExpandedVector: Storable; fn dim(&self) -> usize; fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result; + fn expand_vector( + &mut self, + rng: &mut R, + array: &[f64], + ) -> Result + where + R: rand::Rng + ?Sized; + + fn vector_coord(&self) -> Option { + None + } fn inv_transform_normalize( &mut self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, _untransformed_position: &[f64], _untransformed_gradient: &[f64], _transformed_position: &mut [f64], @@ -569,7 +663,7 @@ pub trait CpuLogpFunc { fn init_from_untransformed_position( &mut self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, _untransformed_position: &[f64], _untransformed_gradient: &mut [f64], _transformed_position: &mut [f64], @@ -580,7 +674,7 @@ pub trait CpuLogpFunc { fn init_from_transformed_position( &mut self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, _untransformed_position: &mut [f64], _untransformed_gradient: &mut [f64], _transformed_position: &[f64], @@ -595,7 +689,7 @@ pub trait CpuLogpFunc { _untransformed_positions: impl ExactSizeIterator, _untransformed_gradients: impl ExactSizeIterator, _untransformed_logp: impl ExactSizeIterator, - _params: &'a mut Self::TransformParams, + _params: &'a mut Self::FlowParameters, ) -> Result<(), Self::LogpError> { unimplemented!() } @@ -606,11 +700,20 @@ pub trait CpuLogpFunc { _untransformed_position: &[f64], _untransformed_gradient: &[f64], _chain: u64, - ) -> Result { + ) -> Result { unimplemented!() } - fn transformation_id(&self, _params: &Self::TransformParams) -> Result { + fn transformation_id(&self, _params: &Self::FlowParameters) -> Result { unimplemented!() } } + +impl Clone for CpuMath { + fn clone(&self) -> Self { + Self { + logp_func: self.logp_func.clone(), + arch: self.arch, + } + } +} diff --git a/src/euclidean_hamiltonian.rs b/src/euclidean_hamiltonian.rs index ae5be9a..5059171 100644 --- a/src/euclidean_hamiltonian.rs +++ b/src/euclidean_hamiltonian.rs @@ -1,17 +1,16 @@ use std::marker::PhantomData; use std::sync::Arc; -use arrow::array::{ArrayBuilder, Float64Builder, StructArray}; -use arrow::datatypes::{DataType, Field}; +use nuts_derive::Storable; +use nuts_storable::HasDims; +use crate::LogpError; use crate::hamiltonian::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point}; use crate::mass_matrix::MassMatrix; use crate::math_base::Math; use crate::nuts::{Collector, NutsError}; -use crate::sampler::Settings; -use crate::sampler_stats::{SamplerStats, StatTraceBuilder}; +use crate::sampler_stats::{SamplerStats, StatsDims}; use crate::state::{State, StatePool}; -use crate::LogpError; pub struct EuclideanHamiltonian> { pub(crate) mass_matrix: Mass, @@ -51,35 +50,15 @@ pub struct EuclideanPoint { pub initial_energy: f64, } -pub struct PointStatsBuilder {} - -impl StatTraceBuilder> for PointStatsBuilder { - fn append_value(&mut self, _math: Option<&mut M>, _value: &EuclideanPoint) { - let Self {} = self; - } - - fn finalize(self) -> Option { - let Self {} = self; - None - } - - fn inspect(&self) -> Option { - let Self {} = self; - None - } -} +#[derive(Debug, Storable)] +pub struct PointStats {} impl SamplerStats for EuclideanPoint { - type Builder = PointStatsBuilder; - type StatOptions = (); + type Stats = 
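+    // `PointStats` is an empty struct: Euclidean points currently expose no
+    // per-draw statistics of their own.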
PointStats; + type StatsOptions = (); - fn new_builder( - &self, - _stat_options: Self::StatOptions, - _settings: &impl Settings, - _dim: usize, - ) -> Self::Builder { - Self::Builder {} + fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { + PointStats {} } } @@ -216,87 +195,24 @@ impl Point for EuclideanPoint { } } -pub struct PotentialStatsBuilder { - mass_matrix: B, - step_size: Float64Builder, -} - -impl> StatTraceBuilder> - for PotentialStatsBuilder -{ - fn append_value(&mut self, math: Option<&mut M>, value: &EuclideanHamiltonian) { - let math = math.expect("Sampler stats needs math"); - let Self { - mass_matrix, - step_size, - } = self; - - mass_matrix.append_value(Some(math), &value.mass_matrix); - step_size.append_value(value.step_size); - } - - fn finalize(self) -> Option { - let Self { - mass_matrix, - mut step_size, - } = self; - - let mut fields = vec![Field::new("step_size", DataType::Float64, false)]; - let mut arrays = vec![ArrayBuilder::finish(&mut step_size)]; - - if let Some(mass_matrix) = mass_matrix.finalize() { - let (m_fields, m_data, m_bitmap) = mass_matrix.into_parts(); - assert!(m_bitmap.is_none()); - fields.extend( - m_fields - .into_iter() - .map(|v| Arc::unwrap_or_clone(v.to_owned())), - ); - arrays.extend(m_data); - } - - Some(StructArray::new(fields.into(), arrays, None)) - } - - fn inspect(&self) -> Option { - let Self { - mass_matrix, - step_size, - } = self; - - let mut fields = vec![Field::new("step_size", DataType::Float64, false)]; - let mut arrays = vec![ArrayBuilder::finish_cloned(step_size)]; - - if let Some(mass_matrix) = mass_matrix.inspect() { - let (m_fields, m_data, m_bitmap) = mass_matrix.into_parts(); - assert!(m_bitmap.is_none()); - fields.extend( - m_fields - .into_iter() - .map(|v| Arc::unwrap_or_clone(v.to_owned())), - ); - arrays.extend(m_data); - } - - Some(StructArray::new(fields.into(), arrays, None)) - } +#[derive(Debug, Storable)] +pub struct PotentialStats> { + #[storable(flatten)] + pub mass_matrix: B, + pub step_size: f64, + #[storable(ignore)] + _phantom: PhantomData P>, } impl> SamplerStats for EuclideanHamiltonian { - type Builder = PotentialStatsBuilder; - type StatOptions = Mass::StatOptions; + type Stats = PotentialStats; + type StatsOptions = Mass::StatsOptions; - fn new_builder( - &self, - stat_options: Self::StatOptions, - settings: &impl Settings, - dim: usize, - ) -> Self::Builder { - Self::Builder { - mass_matrix: self.mass_matrix.new_builder(stat_options, settings, dim), - step_size: Float64Builder::with_capacity( - settings.hint_num_draws() + settings.hint_num_tune(), - ), + fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats { + PotentialStats { + mass_matrix: self.mass_matrix.extract_stats(math, opt), + step_size: self.step_size, + _phantom: PhantomData, } } } diff --git a/src/hamiltonian.rs b/src/hamiltonian.rs index 1584904..e4abcf8 100644 --- a/src/hamiltonian.rs +++ b/src/hamiltonian.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use rand_distr::{Distribution, StandardUniform}; use crate::{ + Math, NutsError, nuts::Collector, sampler_stats::SamplerStats, state::{State, StatePool}, - Math, NutsError, }; /// Details about a divergence that might have occured during sampling diff --git a/src/lib.rs b/src/lib.rs index b4798a0..b722803 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,9 +10,10 @@ //! //! ``` //! use nuts_rs::{CpuLogpFunc, CpuMath, LogpError, DiagGradNutsSettings, Chain, Progress, -//! Settings}; +//! Settings, HasDims}; //! use thiserror::Error; //! 
use rand::thread_rng; +//! use std::collections::HashMap; //! //! // Define a function that computes the unnormalized posterior density //! // and its gradient. @@ -26,11 +27,18 @@ //! fn is_recoverable(&self) -> bool { false } //! } //! +//! impl HasDims for PosteriorDensity { +//! fn dim_sizes(&self) -> HashMap { +//! vec![("unconstrained_parameter".to_string(), self.dim() as u64)].into_iter().collect() +//! } +//! } +//! //! impl CpuLogpFunc for PosteriorDensity { //! type LogpError = PosteriorLogpError; +//! type ExpandedVector = Vec; //! //! // Only used for transforming adaptation. -//! type TransformParams = (); +//! type FlowParameters = (); //! //! // We define a 10 dimensional normal distribution //! fn dim(&self) -> usize { 10 } @@ -50,6 +58,10 @@ //! .sum(); //! return Ok(logp) //! } +//! +//! fn expand_vector(&mut self, rng: &mut R, position: &[f64]) -> Result, nuts_rs::CpuMathError> { +//! Ok(position.to_vec()) +//! } //! } //! //! // We get the default sampler arguments @@ -78,6 +90,8 @@ //! //! Users can also implement the `Model` trait for more control and parallel sampling. //! +//! See the examples directory in the repository for more examples. +//! //! ## Implementation details //! //! This crate mostly follows the implementation of NUTS in [Stan](https://mc-stan.org) and @@ -89,33 +103,45 @@ mod chain; mod cpu_math; mod euclidean_hamiltonian; mod hamiltonian; -mod low_rank_mass_matrix; mod mass_matrix; -mod mass_matrix_adapt; mod math; mod math_base; +mod model; mod nuts; mod sampler; mod sampler_stats; mod state; mod stepsize; -mod stepsize_adapt; +mod storage; mod transform_adapt_strategy; mod transformed_hamiltonian; +pub use nuts_derive::Storable; +pub use nuts_storable::{HasDims, ItemType, Storable, Value}; + pub use adapt_strategy::EuclideanAdaptOptions; pub use chain::Chain; -pub use cpu_math::{CpuLogpFunc, CpuMath}; +pub use cpu_math::{CpuLogpFunc, CpuMath, CpuMathError}; pub use hamiltonian::DivergenceInfo; pub use math_base::{LogpError, Math}; +pub use model::Model; pub use nuts::NutsError; pub use sampler::{ - sample_sequentially, ChainOutput, ChainProgress, DiagGradNutsSettings, DrawStorage, - LowRankNutsSettings, Model, NutsSettings, Progress, ProgressCallback, Sampler, - SamplerWaitResult, Settings, Trace, TransformedNutsSettings, + ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, NutsSettings, Progress, + ProgressCallback, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings, + sample_sequentially, }; +pub use sampler_stats::SamplerStats; -pub use low_rank_mass_matrix::LowRankSettings; -pub use mass_matrix_adapt::DiagAdaptExpSettings; -pub use stepsize_adapt::DualAverageSettings; +pub use mass_matrix::DiagAdaptExpSettings; +pub use mass_matrix::LowRankSettings; +pub use stepsize::{AdamOptions, StepSizeAdaptMethod, StepSizeAdaptOptions, StepSizeSettings}; pub use transform_adapt_strategy::TransformedSettings; + +#[cfg(feature = "zarr")] +pub use storage::{ZarrAsyncConfig, ZarrAsyncTraceStorage, ZarrConfig, ZarrTraceStorage}; + +pub use storage::{CsvConfig, CsvTraceStorage}; +pub use storage::{HashMapConfig, HashMapValue}; +#[cfg(feature = "ndarray")] +pub use storage::{NdarrayConfig, NdarrayTrace, NdarrayValue}; diff --git a/src/mass_matrix_adapt.rs b/src/mass_matrix/adapt.rs similarity index 93% rename from src/mass_matrix_adapt.rs rename to src/mass_matrix/adapt.rs index 419243a..28166ba 100644 --- a/src/mass_matrix_adapt.rs +++ b/src/mass_matrix/adapt.rs @@ -1,14 +1,16 @@ use std::marker::PhantomData; +use nuts_derive::Storable; use 
rand::Rng; +use serde::Serialize; +use super::mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance}; use crate::{ + Math, NutsError, euclidean_hamiltonian::EuclideanPoint, hamiltonian::Point, - mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance}, nuts::{Collector, NutsOptions}, sampler_stats::SamplerStats, - Math, NutsError, Settings, }; const LOWER_LIMIT: f64 = 1e-20f64; const UPPER_LIMIT: f64 = 1e20f64; @@ -17,7 +19,7 @@ const INIT_LOWER_LIMIT: f64 = 1e-20f64; const INIT_UPPER_LIMIT: f64 = 1e20f64; /// Settings for mass matrix adaptation -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Serialize)] pub struct DiagAdaptExpSettings { pub store_mass_matrix: bool, pub use_grad_based_estimate: bool, @@ -41,6 +43,18 @@ pub struct Strategy { _phantom: PhantomData, } +#[derive(Debug, Storable)] +pub struct Stats {} + +impl SamplerStats for Strategy { + type Stats = Stats; + type StatsOptions = (); + + fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { + Stats {} + } +} + pub trait MassMatrixAdaptStrategy: SamplerStats { type MassMatrix: MassMatrix; type Collector: Collector>; @@ -165,18 +179,3 @@ impl MassMatrixAdaptStrategy for Strategy { DrawGradCollector::new(math) } } - -pub type StatsBuilder = (); - -impl SamplerStats for Strategy { - type Builder = StatsBuilder; - type StatOptions = (); - - fn new_builder( - &self, - _stat_options: Self::StatOptions, - _settings: &impl Settings, - _dim: usize, - ) -> Self::Builder { - } -} diff --git a/src/low_rank_mass_matrix.rs b/src/mass_matrix/low_rank.rs similarity index 73% rename from src/low_rank_mass_matrix.rs rename to src/mass_matrix/low_rank.rs index 316dbe4..99f10cc 100644 --- a/src/low_rank_mass_matrix.rs +++ b/src/mass_matrix/low_rank.rs @@ -1,19 +1,15 @@ use std::collections::VecDeque; -use arrow::{ - array::{ArrayBuilder, FixedSizeListBuilder, ListBuilder, PrimitiveBuilder, StructArray}, - datatypes::{Field, Float64Type, UInt64Type}, -}; use faer::{Col, ColRef, Mat, MatRef, Scale}; use itertools::Itertools; +use nuts_derive::Storable; +use serde::Serialize; +use super::adapt::MassMatrixAdaptStrategy; +use super::mass_matrix::{DrawGradCollector, MassMatrix}; use crate::{ - euclidean_hamiltonian::EuclideanPoint, - hamiltonian::Point, - mass_matrix::{DrawGradCollector, MassMatrix}, - mass_matrix_adapt::MassMatrixAdaptStrategy, - sampler_stats::{SamplerStats, StatTraceBuilder}, - Math, NutsError, + Math, NutsError, euclidean_hamiltonian::EuclideanPoint, hamiltonian::Point, + sampler_stats::SamplerStats, }; fn mat_all_finite(mat: &MatRef) -> bool { @@ -119,7 +115,7 @@ impl LowRankMassMatrix { } } -#[derive(Clone, Debug, Copy)] +#[derive(Clone, Debug, Copy, Serialize)] pub struct LowRankSettings { pub store_mass_matrix: bool, pub gamma: f64, @@ -136,155 +132,38 @@ impl Default for LowRankSettings { } } -pub struct MatrixBuilder { - eigenvals: Option>>, - stds: Option>>, - num_eigenvalues: PrimitiveBuilder, -} - -impl StatTraceBuilder> for MatrixBuilder { - fn append_value(&mut self, math: Option<&mut M>, value: &LowRankMassMatrix) { - let math = math.expect("Need reference to math for stats"); - let Self { - eigenvals, - stds, - num_eigenvalues, - } = self; - - if let Some(store) = eigenvals { - if let Some(inner) = &value.inner { - store - .values() - .append_slice(&math.eigs_as_array(&inner.vals)); - store.append(true); - } else { - store.append(false); - } - } - if let Some(store) = stds { - store.values().append_slice(&math.box_array(&value.stds)); - 
store.append(true); - } - - num_eigenvalues.append_value( - value - .inner - .as_ref() - .map(|inner| inner.num_eigenvalues) - .unwrap_or(0), - ); - } - - fn finalize(self) -> Option { - let Self { - eigenvals, - stds, - mut num_eigenvalues, - } = self; - - let num_eigenvalues = ArrayBuilder::finish(&mut num_eigenvalues); - - let mut fields = vec![Field::new( - "mass_matrix_num_eigenvalues", - arrow::datatypes::DataType::UInt64, - false, - )]; - let mut arrays = vec![num_eigenvalues]; - - if let Some(mut eigenvals) = eigenvals { - let eigenvals = ArrayBuilder::finish(&mut eigenvals); - fields.push(Field::new( - "mass_matrix_eigenvals", - eigenvals.data_type().clone(), - true, - )); - - arrays.push(eigenvals); - } - - if let Some(mut stds) = stds { - let stds = ArrayBuilder::finish(&mut stds); - fields.push(Field::new( - "mass_matrix_stds", - stds.data_type().clone(), - true, - )); - - arrays.push(stds); - } - - Some(StructArray::new(fields.into(), arrays, None)) - } - - fn inspect(&self) -> Option { - let Self { - ref eigenvals, - ref stds, - ref num_eigenvalues, - } = self; - - let num_eigenvalues = ArrayBuilder::finish_cloned(num_eigenvalues); - - let mut fields = vec![Field::new( - "mass_matrix_num_eigenvalues", - arrow::datatypes::DataType::UInt64, - false, - )]; - let mut arrays = vec![num_eigenvalues]; - - if let Some(eigenvals) = &eigenvals { - let eigenvals = ArrayBuilder::finish_cloned(eigenvals); - fields.push(Field::new( - "mass_matrix_eigenvals", - eigenvals.data_type().clone(), - true, - )); - - arrays.push(eigenvals); - } - - if let Some(stds) = &stds { - let stds = ArrayBuilder::finish_cloned(stds); - fields.push(Field::new( - "mass_matrix_stds", - stds.data_type().clone(), - true, - )); - - arrays.push(stds); - } - Some(StructArray::new(fields.into(), arrays, None)) - } +#[derive(Debug, Storable)] +pub struct MatrixStats { + pub eigvals: Option>, + pub stds: Option>, + pub num_eigenvalues: u64, } impl SamplerStats for LowRankMassMatrix { - type Builder = MatrixBuilder; - type StatOptions = (); + type Stats = MatrixStats; + type StatsOptions = (); - fn new_builder( - &self, - _stat_options: Self::StatOptions, - _settings: &impl crate::Settings, - dim: usize, - ) -> Self::Builder { - let num_eigenvalues = PrimitiveBuilder::new(); + fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { if self.settings.store_mass_matrix { - let items = PrimitiveBuilder::new(); - let eigenvals = Some(ListBuilder::new(items)); - - let items = PrimitiveBuilder::new(); - let stds = Some(FixedSizeListBuilder::new(items, dim as _)); - - MatrixBuilder { - eigenvals, - stds, - num_eigenvalues, + let eigvals = self + .inner + .as_ref() + .map(|inner| math.eigs_as_array(&inner.vals)); + let stds = Some(math.box_array(&self.stds)); + MatrixStats { + eigvals: eigvals.map(|x| x.into_vec()), + stds: stds.map(|x| x.into_vec()), + num_eigenvalues: self + .inner + .as_ref() + .map(|inner| inner.num_eigenvalues) + .unwrap_or(0), } } else { - MatrixBuilder { - eigenvals: None, + MatrixStats { + eigvals: None, stds: None, - num_eigenvalues, + num_eigenvalues: 0, } } } @@ -343,23 +222,6 @@ pub struct Stats { } */ -#[derive(Debug)] -pub struct Builder {} - -impl StatTraceBuilder for Builder { - fn append_value(&mut self, _math: Option<&mut M>, _value: &LowRankMassMatrixStrategy) { - let Self {} = self; - } - - fn finalize(self) -> Option { - None - } - - fn inspect(&self) -> Option { - None - } -} - #[derive(Debug)] pub struct LowRankMassMatrixStrategy { draws: VecDeque>, @@ -445,7 +307,7 
@@ impl LowRankMassMatrixStrategy { let filtered = vals .iter() .zip(vecs.col_iter()) - .filter(|(&val, _)| { + .filter(|&(&val, _)| { (val > self.settings.eigval_cutoff) | (val < self.settings.eigval_cutoff.recip()) }) .collect_vec(); @@ -466,7 +328,8 @@ impl LowRankMassMatrixStrategy { fn rescale_points(draws: &mut Mat, grads: &mut Mat) -> Col { let (ndim, ndraws) = draws.shape(); - let stds = Col::from_fn(ndim, |col| { + + Col::from_fn(ndim, |col| { let draw_mean = draws.row(col).sum() / (ndraws as f64); let grad_mean = grads.row(col).sum() / (ndraws as f64); let draw_std: f64 = draws @@ -497,8 +360,7 @@ fn rescale_points(draws: &mut Mat, grads: &mut Mat) -> Col { .for_each(|val| *val = (*val - grad_mean) * grad_scale); std - }); - stds + }) } fn estimate_mass_matrix( @@ -563,17 +425,10 @@ fn spd_mean(cov_draws: Mat, cov_grads: Mat) -> Option> { } impl SamplerStats for LowRankMassMatrixStrategy { - type Builder = Builder; - type StatOptions = (); + type Stats = (); + type StatsOptions = (); - fn new_builder( - &self, - _stat_options: Self::StatOptions, - _settings: &impl crate::Settings, - _dim: usize, - ) -> Self::Builder { - Builder {} - } + fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {} } impl MassMatrixAdaptStrategy for LowRankMassMatrixStrategy { @@ -646,13 +501,11 @@ mod test { use std::ops::AddAssign; use equator::Cmp; - use faer::{utils::approx::ApproxEq, Col, Mat}; - use rand::{rngs::SmallRng, Rng, SeedableRng}; + use faer::{Col, Mat, utils::approx::ApproxEq}; + use rand::{Rng, SeedableRng, rngs::SmallRng}; use rand_distr::StandardNormal; - use crate::low_rank_mass_matrix::mat_all_finite; - - use super::{estimate_mass_matrix, spd_mean}; + use super::{estimate_mass_matrix, mat_all_finite, spd_mean}; #[test] fn test_spd_mean() { diff --git a/src/mass_matrix.rs b/src/mass_matrix/mass_matrix.rs similarity index 67% rename from src/mass_matrix.rs rename to src/mass_matrix/mass_matrix.rs index 2f0219c..288c9a9 100644 --- a/src/mass_matrix.rs +++ b/src/mass_matrix/mass_matrix.rs @@ -1,16 +1,8 @@ -use arrow::{ - array::{ArrayBuilder, FixedSizeListBuilder, PrimitiveBuilder, StructArray}, - datatypes::{Field, Float64Type}, -}; +use nuts_derive::Storable; use crate::{ - euclidean_hamiltonian::EuclideanPoint, - hamiltonian::Point, - math_base::Math, - nuts::Collector, - sampler::Settings, - sampler_stats::{SamplerStats, StatTraceBuilder}, - state::State, + euclidean_hamiltonian::EuclideanPoint, hamiltonian::Point, math_base::Math, nuts::Collector, + sampler_stats::SamplerStats, state::State, }; pub trait MassMatrix: SamplerStats { @@ -24,10 +16,6 @@ pub trait MassMatrix: SamplerStats { ); } -pub struct NullCollector {} - -impl> Collector for NullCollector {} - #[derive(Debug)] pub struct DiagMassMatrix { inv_stds: M::Vector, @@ -35,68 +23,23 @@ pub struct DiagMassMatrix { store_mass_matrix: bool, } -pub struct DiagMassMatrixStatsBuilder { - mass_matrix_inv: Option>>, -} - -impl StatTraceBuilder> for DiagMassMatrixStatsBuilder { - fn append_value(&mut self, math: Option<&mut M>, value: &DiagMassMatrix) { - let math = math.expect("Need reference to math for stats"); - let Self { mass_matrix_inv } = self; - - if let Some(store) = mass_matrix_inv { - let values = math.box_array(&value.variance); - store.values().append_slice(&values); - store.append(true); - } - } - - fn finalize(self) -> Option { - let Self { mass_matrix_inv } = self; - - let array = ArrayBuilder::finish(&mut mass_matrix_inv?); - - let fields = vec![Field::new( - "mass_matrix_inv", - 
array.data_type().clone(),
-            true,
-        )];
-        let arrays = vec![array];
-        Some(StructArray::new(fields.into(), arrays, None))
-    }
-
-    fn inspect(&self) -> Option {
-        let Self { mass_matrix_inv } = self;
-
-        let array = ArrayBuilder::finish_cloned(mass_matrix_inv.as_ref()?);
-        let fields = vec![Field::new(
-            "mass_matrix_inv",
-            array.data_type().clone(),
-            true,
-        )];
-        let arrays = vec![array];
-        Some(StructArray::new(fields.into(), arrays, None))
-    }
+#[derive(Debug, Storable)]
+pub struct DiagMassMatrixStats {
+    #[storable(dims("unconstrained_parameter"))]
+    pub mass_matrix_inv: Option<Vec<f64>>,
 }
 
 impl SamplerStats for DiagMassMatrix {
-    type Builder = DiagMassMatrixStatsBuilder;
-    type StatOptions = ();
+    type Stats = DiagMassMatrixStats;
+    type StatsOptions = ();
 
-    fn new_builder(
-        &self,
-        _stat_options: Self::StatOptions,
-        _settings: &impl Settings,
-        dim: usize,
-    ) -> Self::Builder {
+    fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats {
         if self.store_mass_matrix {
-            let items = PrimitiveBuilder::new();
-            let values = FixedSizeListBuilder::new(items, dim as _);
-            Self::Builder {
-                mass_matrix_inv: Some(values),
+            DiagMassMatrixStats {
+                mass_matrix_inv: Some(math.box_array(&self.variance).into_vec()),
             }
         } else {
-            Self::Builder {
+            DiagMassMatrixStats {
                 mass_matrix_inv: None,
             }
         }
     }
diff --git a/src/mass_matrix/mod.rs b/src/mass_matrix/mod.rs
new file mode 100644
index 0000000..8409350
--- /dev/null
+++ b/src/mass_matrix/mod.rs
@@ -0,0 +1,10 @@
+mod adapt;
+mod low_rank;
+mod mass_matrix;
+
+pub use adapt::DiagAdaptExpSettings;
+pub(crate) use adapt::MassMatrixAdaptStrategy;
+pub(crate) use adapt::Strategy;
+pub use low_rank::LowRankSettings;
+pub(crate) use low_rank::{LowRankMassMatrix, LowRankMassMatrixStrategy};
+pub(crate) use mass_matrix::{DiagMassMatrix, MassMatrix};
diff --git a/src/math_base.rs b/src/math_base.rs
index 82e5b50..65d6d4b 100644
--- a/src/math_base.rs
+++ b/src/math_base.rs
@@ -1,5 +1,8 @@
 use std::{error::Error, fmt::Debug};
 
+use nuts_storable::{HasDims, Storable, Value};
+use rand::Rng;
+
 /// Errors that happen when we evaluate the logp and gradient function
 pub trait LogpError: std::error::Error + Send {
     /// Unrecoverable errors during logp computation stop sampling,
@@ -7,13 +10,14 @@ pub trait LogpError: std::error::Error + Send {
     fn is_recoverable(&self) -> bool;
 }
 
-pub trait Math {
+pub trait Math: HasDims {
     type Vector: Debug;
     type EigVectors: Debug;
     type EigValues: Debug;
     type LogpErr: Debug + Send + Sync + LogpError + Sized + 'static;
     type Err: Debug + Send + Sync + Error + 'static;
-    type TransformParams;
+    type FlowParameters;
+    type ExpandedVector: Storable;
 
     fn new_array(&mut self) -> Self::Vector;
 
@@ -45,8 +49,27 @@ pub trait Math {
 
     fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result;
 
+    fn init_position(
+        &mut self,
+        rng: &mut R,
+        position: &mut Self::Vector,
+        gradient: &mut Self::Vector,
+    ) -> Result;
+
+    /// Expand a vector into a larger representation, for instance to
+    /// compute deterministic values that are to be stored in the trace.
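+    ///
+    /// A pass-through implementation that stores the unconstrained draw
+    /// unchanged is often enough. A minimal sketch (this assumes
+    /// `Self::ExpandedVector = Vec<f64>`; the generic bounds and error type
+    /// are those of the trait method below):
+    ///
+    /// ```ignore
+    /// fn expand_vector<R: rand::Rng + ?Sized>(
+    ///     &mut self,
+    ///     _rng: &mut R,
+    ///     array: &Self::Vector,
+    /// ) -> Result<Self::ExpandedVector, Self::Err> {
+    ///     // No derived quantities: copy the draw as-is.
+    ///     Ok(self.box_array(array).into_vec())
+    /// }
+    /// ```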
+ fn expand_vector( + &mut self, + rng: &mut R, + array: &Self::Vector, + ) -> Result; + fn dim(&self) -> usize; + fn vector_coord(&self) -> Option { + None + } + fn scalar_prods3( &mut self, positive1: &Self::Vector, @@ -144,7 +167,7 @@ pub trait Math { fn inv_transform_normalize( &mut self, - params: &Self::TransformParams, + params: &Self::FlowParameters, untransformed_position: &Self::Vector, untransofrmed_gradient: &Self::Vector, transformed_position: &mut Self::Vector, @@ -153,7 +176,7 @@ pub trait Math { fn init_from_untransformed_position( &mut self, - params: &Self::TransformParams, + params: &Self::FlowParameters, untransformed_position: &Self::Vector, untransformed_gradient: &mut Self::Vector, transformed_position: &mut Self::Vector, @@ -162,7 +185,7 @@ pub trait Math { fn init_from_transformed_position( &mut self, - params: &Self::TransformParams, + params: &Self::FlowParameters, untransformed_position: &mut Self::Vector, untransformed_gradient: &mut Self::Vector, transformed_position: &Self::Vector, @@ -175,7 +198,7 @@ pub trait Math { untransformed_positions: impl ExactSizeIterator, untransformed_gradients: impl ExactSizeIterator, untransformed_logps: impl ExactSizeIterator, - params: &'a mut Self::TransformParams, + params: &'a mut Self::FlowParameters, ) -> Result<(), Self::LogpErr>; fn new_transformation( @@ -184,7 +207,7 @@ pub trait Math { untransformed_position: &Self::Vector, untransfogmed_gradient: &Self::Vector, chain: u64, - ) -> Result; + ) -> Result; - fn transformation_id(&self, params: &Self::TransformParams) -> Result; + fn transformation_id(&self, params: &Self::FlowParameters) -> Result; } diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..eb3dbac --- /dev/null +++ b/src/model.rs @@ -0,0 +1,37 @@ +//! Core abstractions for MCMC models. +//! +//! Provides the `Model` trait which defines the interface for MCMC models, +//! including the math backend and initialization methods needed for sampling. + +use anyhow::Result; +use rand::Rng; + +use crate::math_base::Math; + +/// Trait for MCMC models with associated math backend and initialization. +/// +/// Defines the interface for models that can be used with MCMC sampling algorithms. +/// Provides access to mathematical operations needed for sampling and methods for +/// initializing the sampling position. +/// +/// The trait is thread-safe to enable parallel sampling scenarios. +pub trait Model: Send + Sync + 'static { + /// The math backend used by this MCMC model. + /// + /// Specifies which math implementation will be used for computing log probability + /// densities, gradients, and other operations required during sampling. + /// + /// The lifetime parameter allows the math backend to borrow from the model instance. + type Math<'model>: Math + where + Self: 'model; + + /// Returns the math backend for this model. + fn math(&self, rng: &mut R) -> Result>; + + /// Initializes the starting position for MCMC sampling. + /// + /// Sets initial values for the parameter vector. The starting position should + /// be in a reasonable region where the log probability density is finite. 
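+    ///
+    /// A uniform jitter on the unconstrained scale (as Stan does by default)
+    /// is a common choice. A minimal sketch, not part of the trait contract:
+    ///
+    /// ```ignore
+    /// fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
+    ///     // Draw each coordinate from Uniform(-2, 2).
+    ///     position.iter_mut().for_each(|p| *p = rng.random_range(-2.0..2.0));
+    ///     Ok(())
+    /// }
+    /// ```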
+ fn init_position(&self, rng: &mut R, position: &mut [f64]) -> Result<()>; +} diff --git a/src/nuts.rs b/src/nuts.rs index 0072c69..6c5af53 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -250,6 +250,7 @@ impl, C: Collector> NutsTree { pub struct NutsOptions { pub maxdepth: u64, + pub mindepth: u64, pub store_gradient: bool, pub store_unconstrained: bool, pub check_turning: bool, @@ -286,9 +287,13 @@ where tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) { ExtendResult::Ok(tree) => tree, ExtendResult::Turning(tree) => { - let info = tree.info(false, None); - collector.register_draw(math, &tree.draw, &info); - return Ok((tree.draw, info)); + if tree.depth < options.mindepth { + tree + } else { + let info = tree.info(false, None); + collector.register_draw(math, &tree.draw, &info); + return Ok((tree.draw, info)); + } } ExtendResult::Diverging(tree, info) => { let info = tree.info(false, Some(info)); @@ -307,15 +312,11 @@ where #[cfg(test)] mod tests { - use rand::{rng, rngs::ThreadRng}; + use rand::rng; use crate::{ - adapt_strategy::test_logps::NormalLogp, - chain::NutsChain, - cpu_math::CpuMath, + Chain, Settings, adapt_strategy::test_logps::NormalLogp, cpu_math::CpuMath, sampler::DiagGradNutsSettings, - sampler_stats::{SamplerStats, StatTraceBuilder}, - Chain, Settings, }; #[test] @@ -329,17 +330,12 @@ mod tests { let mut chain = settings.new_chain(0, math, &mut rng); - let opt_settings = settings.stats_options(&chain); - let mut builder = chain.new_builder(opt_settings, &settings, ndim); - let (_, mut progress) = chain.draw().unwrap(); for _ in 0..10 { let (_, prog) = chain.draw().unwrap(); progress = prog; - builder.append_value(None, &chain); } assert!(!progress.diverging); - StatTraceBuilder::<_, NutsChain<_, ThreadRng, _>>::finalize(builder); } } diff --git a/src/sampler.rs b/src/sampler.rs index 0bb4f7c..0ccb9ca 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -1,38 +1,45 @@ -use anyhow::{bail, Context, Result}; -use arrow::array::Array; +use anyhow::{Context, Result, bail}; use itertools::Itertools; -use rand::{rngs::SmallRng, Rng, SeedableRng}; +use nuts_storable::{HasDims, Storable, Value}; +use rand::{Rng, SeedableRng, rngs::SmallRng}; use rand_chacha::ChaCha8Rng; use rayon::{ScopeFifo, ThreadPoolBuilder}; +use serde::Serialize; use std::{ + collections::HashMap, fmt::Debug, + ops::Deref, sync::{ + Arc, Mutex, mpsc::{ - channel, sync_channel, Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, + Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel, }, - Arc, Mutex, }, - thread::{spawn, JoinHandle}, + thread::{JoinHandle, spawn}, time::{Duration, Instant}, }; use crate::{ - adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy}, + DiagAdaptExpSettings, + adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions}, chain::{AdaptStrategy, Chain, NutsChain, StatOptions}, euclidean_hamiltonian::EuclideanHamiltonian, - low_rank_mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings}, mass_matrix::DiagMassMatrix, - mass_matrix_adapt::Strategy as DiagMassMatrixStrategy, + mass_matrix::Strategy as DiagMassMatrixStrategy, + mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings}, math_base::Math, + model::Model, nuts::NutsOptions, - sampler_stats::{SamplerStats, StatTraceBuilder}, + sampler_stats::{SamplerStats, StatsDims}, + storage::{ChainStorage, StorageConfig, TraceStorage}, transform_adapt_strategy::{TransformAdaptation, TransformedSettings}, 
transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions}, - DiagAdaptExpSettings, }; /// All sampler configurations implement this trait -pub trait Settings: private::Sealed + Clone + Copy + Default + Sync + Send + 'static { +pub trait Settings: + private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + 'static +{ type Chain: Chain; fn new_chain( @@ -46,10 +53,83 @@ pub trait Settings: private::Sealed + Clone + Copy + Default + Sync + Send + 'st fn hint_num_draws(&self) -> usize; fn num_chains(&self) -> usize; fn seed(&self) -> u64; - fn stats_options( - &self, - chain: &Self::Chain, - ) -> as SamplerStats>::StatOptions; + fn stats_options(&self) -> as SamplerStats>::StatsOptions; + + fn stat_names(&self, math: &M) -> Vec { + let dims = StatsDims::from(math); + < as SamplerStats>::Stats as Storable<_>>::names(&dims) + .into_iter() + .map(String::from) + .collect() + } + + fn data_names(&self, math: &M) -> Vec { + >::names(math) + .into_iter() + .map(String::from) + .collect() + } + + fn stat_types(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> { + self.stat_names(math) + .into_iter() + .map(|name| (name.clone(), self.stat_type::(math, &name))) + .collect() + } + + fn stat_type(&self, math: &M, name: &str) -> nuts_storable::ItemType { + let dims = StatsDims::from(math); + < as SamplerStats>::Stats as Storable<_>>::item_type(&dims, name) + } + + fn data_types(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> { + self.data_names(math) + .into_iter() + .map(|name| (name.clone(), self.data_type(math, &name))) + .collect() + } + fn data_type(&self, math: &M, name: &str) -> nuts_storable::ItemType { + >::item_type(math, name) + } + + fn stat_dims_all(&self, math: &M) -> Vec<(String, Vec)> { + self.stat_names(math) + .into_iter() + .map(|name| (name.clone(), self.stat_dims::(math, &name))) + .collect() + } + + fn stat_dims(&self, math: &M, name: &str) -> Vec { + let dims = StatsDims::from(math); + < as SamplerStats>::Stats as Storable<_>>::dims(&dims, name) + .into_iter() + .map(String::from) + .collect() + } + + fn stat_dim_sizes(&self, math: &M) -> HashMap { + let dims = StatsDims::from(math); + dims.dim_sizes() + } + + fn data_dims_all(&self, math: &M) -> Vec<(String, Vec)> { + self.data_names(math) + .into_iter() + .map(|name| (name.clone(), self.data_dims(math, &name))) + .collect() + } + + fn data_dims(&self, math: &M, name: &str) -> Vec { + >::dims(math, name) + .into_iter() + .map(String::from) + .collect() + } + + fn stat_coords(&self, math: &M) -> HashMap { + let dims = StatsDims::from(math); + dims.coords() + } } #[derive(Debug, Clone)] @@ -78,8 +158,8 @@ mod private { } /// Settings for the NUTS sampler -#[derive(Debug, Clone, Copy)] -pub struct NutsSettings { +#[derive(Debug, Clone, Copy, Serialize)] +pub struct NutsSettings { /// The number of tuning steps, where we fit the step size and mass matrix. pub num_tune: u64, /// The number of draws after tuning @@ -87,6 +167,9 @@ pub struct NutsSettings { /// The maximum tree depth during sampling. The number of leapfrog steps /// is smaller than 2 ^ maxdepth. pub maxdepth: u64, + /// The minimum tree depth during sampling. The number of leapfrog steps + /// is larger than 2 ^ mindepth. 
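+    /// Trees that would stop below this depth are extended further instead.
+    /// All provided defaults set this to 0, which disables the lower bound.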
+ pub mindepth: u64, /// Store the gradient in the SampleStats pub store_gradient: bool, /// Store each unconstrained parameter vector in the sampler stats @@ -114,6 +197,7 @@ impl Default for DiagGradNutsSettings { num_tune: 400, num_draws: 1000, maxdepth: 10, + mindepth: 0, max_energy_error: 1000f64, store_gradient: false, store_unconstrained: false, @@ -132,6 +216,7 @@ impl Default for LowRankNutsSettings { num_tune: 800, num_draws: 1000, maxdepth: 10, + mindepth: 0, max_energy_error: 1000f64, store_gradient: false, store_unconstrained: false, @@ -152,6 +237,7 @@ impl Default for TransformedNutsSettings { num_tune: 1500, num_draws: 1000, maxdepth: 10, + mindepth: 0, max_energy_error: 20f64, store_gradient: false, store_unconstrained: false, @@ -187,6 +273,7 @@ impl Settings for LowRankNutsSettings { let options = NutsOptions { maxdepth: self.maxdepth, + mindepth: self.mindepth, store_gradient: self.store_gradient, store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, @@ -195,7 +282,15 @@ impl Settings for LowRankNutsSettings { let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); - NutsChain::new(math, potential, strategy, options, rng, chain) + NutsChain::new( + math, + potential, + strategy, + options, + rng, + chain, + self.stats_options(), + ) } fn hint_num_tune(&self) -> usize { @@ -214,12 +309,12 @@ impl Settings for LowRankNutsSettings { self.seed } - fn stats_options( - &self, - _chain: &Self::Chain, - ) -> as SamplerStats>::StatOptions { + fn stats_options(&self) -> as SamplerStats>::StatsOptions { StatOptions { - adapt: (), + adapt: GlobalStrategyStatsOptions { + mass_matrix: (), + step_size: (), + }, hamiltonian: (), point: (), } @@ -246,6 +341,7 @@ impl Settings for DiagGradNutsSettings { let options = NutsOptions { maxdepth: self.maxdepth, + mindepth: self.mindepth, store_gradient: self.store_gradient, store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, @@ -254,7 +350,15 @@ impl Settings for DiagGradNutsSettings { let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); - NutsChain::new(math, potential, strategy, options, rng, chain) + NutsChain::new( + math, + potential, + strategy, + options, + rng, + chain, + self.stats_options(), + ) } fn hint_num_tune(&self) -> usize { @@ -273,12 +377,12 @@ impl Settings for DiagGradNutsSettings { self.seed } - fn stats_options( - &self, - _chain: &Self::Chain, - ) -> as SamplerStats>::StatOptions { + fn stats_options(&self) -> as SamplerStats>::StatsOptions { StatOptions { - adapt: (), + adapt: GlobalStrategyStatsOptions { + mass_matrix: (), + step_size: (), + }, hamiltonian: (), point: (), } @@ -302,6 +406,7 @@ impl Settings for TransformedNutsSettings { let options = NutsOptions { maxdepth: self.maxdepth, + mindepth: self.mindepth, store_gradient: self.store_gradient, store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, @@ -309,7 +414,15 @@ impl Settings for TransformedNutsSettings { }; let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); - NutsChain::new(math, hamiltonian, strategy, options, rng, chain) + NutsChain::new( + math, + hamiltonian, + strategy, + options, + rng, + chain, + self.stats_options(), + ) } fn hint_num_tune(&self) -> usize { @@ -328,10 +441,7 @@ impl Settings for TransformedNutsSettings { self.seed } - fn stats_options( - &self, - _chain: &Self::Chain, - ) -> as SamplerStats>::StatOptions { + fn 
stats_options(&self) -> as SamplerStats>::StatsOptions { // TODO make extra config let point = TransformedPointStatsOptions { store_transformed: self.store_unconstrained, @@ -357,30 +467,6 @@ pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>( Ok((0..draws).map(move |_| sampler.draw())) } -pub trait DrawStorage: Send { - fn append_value(&mut self, point: &[f64]) -> Result<()>; - fn finalize(self) -> Result>; - fn inspect(&self) -> Result>; -} - -pub trait Model: Send + Sync + 'static { - type Math<'model>: Math - where - Self: 'model; - type DrawStorage<'model, S: Settings>: DrawStorage - where - Self: 'model; - - fn new_trace<'model, S: Settings, R: Rng + ?Sized>( - &'model self, - rng: &mut R, - chain_id: u64, - settings: &'model S, - ) -> Result>; - fn math(&self) -> Result>; - fn init_position(&self, rng: &mut R, position: &mut [f64]) -> Result<()>; -} - #[non_exhaustive] #[derive(Clone, Debug)] pub struct ChainProgress { @@ -427,76 +513,34 @@ impl ChainProgress { } } -pub struct ChainOutput { - pub draws: Arc, - pub stats: Arc, - pub chain_id: u64, -} - enum ChainCommand { Resume, Pause, } -struct ChainTrace<'model, M: Model + 'model, S: Settings> { - draws_builder: M::DrawStorage<'model, S>, - stats_builder: > as SamplerStats>>::Builder, - chain_id: u64, -} - -impl<'model, M: Model + 'model, S: Settings> ChainTrace<'model, M, S> { - fn inspect(&self) -> Result { - let stats = self.stats_builder.inspect().expect("No sample stats"); - let draws = self.draws_builder.inspect()?; - Ok(ChainOutput { - chain_id: self.chain_id, - draws, - stats: Arc::new(stats) as Arc<_>, - }) - } - - fn finalize(self) -> Result { - let draws = self.draws_builder.finalize()?; - let stats = self.stats_builder.finalize().expect("No sample stats"); - Ok(ChainOutput { - chain_id: self.chain_id, - draws, - stats: Arc::new(stats) as Arc, - }) - } -} - -struct ChainProcess<'model, M, S> +struct ChainProcess where - M: Model + 'model, - S: Settings, + T: TraceStorage, { stop_marker: Sender, - trace: Arc>>>, + trace: Arc>>, progress: Arc>, } -impl<'scope, M: Model, S: Settings> ChainProcess<'scope, M, S> { - fn finalize_many(chains: Vec) -> Vec>> { - chains +impl ChainProcess { + fn finalize_many(trace: T, chains: Vec) -> Result<(Option, T::Finalized)> { + let finalized_chain_traces = chains .into_iter() + .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take()) .map(|chain| chain.finalize()) - .collect_vec() + .collect_vec(); + trace.finalize(finalized_chain_traces) } fn progress(&self) -> ChainProgress { self.progress.lock().expect("Poisoned lock").clone() } - fn current_trace(&self) -> Result> { - self.trace - .lock() - .expect("Poisoned lock") - .as_ref() - .map(|trace| trace.inspect()) - .transpose() - } - fn resume(&self) -> Result<()> { self.stop_marker.send(ChainCommand::Resume)?; Ok(()) @@ -507,62 +551,40 @@ impl<'scope, M: Model, S: Settings> ChainProcess<'scope, M, S> { Ok(()) } - fn finalize(self) -> Result> { - drop(self.stop_marker); - self.trace - .lock() - .expect("Poisoned lock") - .take() - .map(|trace| trace.finalize()) - .transpose() - } - - fn start<'model>( + fn start<'model, M: Model, S: Settings>( model: &'model M, + chain_trace: T::ChainStorage, chain_id: u64, seed: u64, settings: &'model S, - scope: &ScopeFifo<'scope>, + scope: &ScopeFifo<'model>, results: Sender>, - ) -> Result - where - 'model: 'scope, - { + ) -> Result { let (stop_marker_tx, stop_marker_rx) = channel(); let mut rng = ChaCha8Rng::seed_from_u64(seed); - rng.set_stream(chain_id); + 
rng.set_stream(chain_id + 1); - let trace = Arc::new(Mutex::new(None)); + let chain_trace = Arc::new(Mutex::new(Some(chain_trace))); let progress = Arc::new(Mutex::new(ChainProgress::new( settings.hint_num_draws() + settings.hint_num_tune(), ))); - let trace_inner = trace.clone(); + let trace_inner = chain_trace.clone(); let progress_inner = progress.clone(); scope.spawn_fifo(move |_| { - let trace = trace_inner; + let chain_trace = trace_inner; let progress = progress_inner; let mut sample = move || { - let logp = model.math().context("Failed to create model density")?; + let logp = model + .math(&mut rng) + .context("Failed to create model density")?; let dim = logp.dim(); let mut sampler = settings.new_chain(chain_id, logp, &mut rng); - let draw_trace = model - .new_trace(&mut rng, chain_id, settings) - .context("Failed to create trace object")?; - let stat_opts = settings.stats_options(&sampler); - let stats_trace = sampler.new_builder(stat_opts, settings, dim); - - let new_trace = ChainTrace { - draws_builder: draw_trace, - stats_builder: stats_trace, - chain_id, - }; - *trace.lock().expect("Poisoned mutex") = Some(new_trace); progress.lock().expect("Poisoned mutex").started = true; let mut initval = vec![0f64; dim]; @@ -603,12 +625,14 @@ impl<'scope, M: Model, S: Settings> ChainProcess<'scope, M, S> { } let now = Instant::now(); - let (point, info) = sampler.draw().unwrap(); - let mut guard = trace + //let (point, info) = sampler.draw().unwrap(); + let (_point, draw_data, stats, info) = sampler.expanded_draw().unwrap(); + + let mut guard = chain_trace .lock() .expect("Could not unlock trace lock. Poisoned mutex"); - let Some(val) = guard.as_mut() else { + let Some(trace_val) = guard.as_mut() else { // The trace was removed by controller thread. We can stop sampling break; }; @@ -616,8 +640,16 @@ impl<'scope, M: Model, S: Settings> ChainProcess<'scope, M, S> { .lock() .expect("Poisoned mutex") .update(&info, now.elapsed()); - DrawStorage::append_value(&mut val.draws_builder, &point)?; - StatTraceBuilder::append_value(&mut val.stats_builder, None, &sampler); + + let math = sampler.math(); + let dims = StatsDims::from(math.deref()); + trace_val.record_sample( + settings, + stats.get_all(&dims), + draw_data.get_all(math.deref()), + &info, + )?; + draw += 1; if draw == draws { break; @@ -630,71 +662,75 @@ impl<'scope, M: Model, S: Settings> ChainProcess<'scope, M, S> { let result = sample(); - results - .send(result) - .expect("Could not send sampling results to main thread."); + // We intentionally ignore errors here, because this means some other + // chain already failed, and should have reported the error. + let _ = results.send(result); drop(results); }); Ok(Self { - trace, + trace: chain_trace, stop_marker: stop_marker_tx, progress, }) } + + fn flush(&self) -> Result<()> { + self.trace + .lock() + .map_err(|_| anyhow::anyhow!("Could not lock trace mutex")) + .context("Could not flush trace")? 
+ .as_mut().map(|v| v.flush()) + .transpose()?; + Ok(()) + } } #[derive(Debug)] enum SamplerCommand { Pause, Continue, - InspectTrace, Progress, + Flush, } enum SamplerResponse { Ok(), - IntermediateTrace(Trace), Progress(Box<[ChainProgress]>), } -pub enum SamplerWaitResult { - Trace(Trace), - Timeout(Sampler), - Err(anyhow::Error, Option), +pub enum SamplerWaitResult { + Trace(F), + Timeout(Sampler), + Err(anyhow::Error, Option), } -pub struct Sampler { - main_thread: JoinHandle>>>>, +pub struct Sampler { + main_thread: JoinHandle, F)>>, commands: SyncSender, responses: Receiver, results: Receiver>, } -pub struct Trace { - pub chains: Vec, -} - -impl> From for Trace { - fn from(value: I) -> Self { - let mut chains = value.into_iter().collect_vec(); - chains.sort_unstable_by_key(|x| x.chain_id); - Trace { chains } - } -} - pub struct ProgressCallback { pub callback: Box) + Send>, pub rate: Duration, } -impl Sampler { - pub fn new( +impl Sampler { + pub fn new( model: M, settings: S, + trace_config: C, num_cores: usize, callback: Option, - ) -> Result { + ) -> Result + where + S: Settings, + C: StorageConfig, + M: Model, + T: TraceStorage, + { let (commands_tx, commands_rx) = sync_channel(0); let (responses_tx, responses_rx) = sync_channel(0); let (results_tx, results_rx) = channel(); @@ -714,9 +750,24 @@ impl Sampler { let results = results_tx; let mut chains = Vec::with_capacity(settings.num_chains()); + let mut rng = ChaCha8Rng::seed_from_u64(settings.seed()); + rng.set_stream(0); + + let math = model_ref + .math(&mut rng) + .context("Could not create model density")?; + let trace = trace_config + .new_trace(settings_ref, &math) + .context("Could not create trace object")?; + drop(math); + for chain_id in 0..settings.num_chains() { + let chain_trace_val = trace + .initialize_trace_for_chain(chain_id as u64) + .context("Failed to create trace object")?; let chain = ChainProcess::start( model_ref, + chain_trace_val, chain_id as u64, settings.seed(), settings_ref, @@ -729,7 +780,7 @@ impl Sampler { let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result(); if let Some(error) = errors.into_iter().next() { - let _ = ChainProcess::finalize_many(chains); + let _ = ChainProcess::finalize_many(trace, chains); return Err(error).context("Could not start chains"); } @@ -787,18 +838,17 @@ impl Sampler { is_paused = false; responses_tx.send(SamplerResponse::Ok())?; } - Ok(SamplerCommand::InspectTrace) => { - let traces: Result> = - chains.iter().map(|chain| chain.current_trace()).collect(); - responses_tx.send(SamplerResponse::IntermediateTrace( - traces?.into_iter().flatten().into(), - ))?; - } Ok(SamplerCommand::Progress) => { let progress = chains.iter().map(|chain| chain.progress()).collect_vec(); responses_tx.send(SamplerResponse::Progress(progress.into()))?; } + Ok(SamplerCommand::Flush) => { + for chain in chains.iter() { + chain.flush()?; + } + responses_tx.send(SamplerResponse::Ok())?; + } Err(RecvTimeoutError::Timeout) => {} Err(RecvTimeoutError::Disconnected) => { if let Some(ProgressCallback { callback, .. 
}) = &mut callback {
@@ -818,10 +868,10 @@
         };
         let result: Result<()> = main_loop();
         // Run finalization even if something failed
-        let output = Ok(ChainProcess::finalize_many(chains));
+        let output = ChainProcess::finalize_many(trace, chains)?;
         result?;
-        output
+        Ok(output)
     })
 });
@@ -856,46 +906,40 @@
         Ok(())
     }
 
-    pub fn abort(self) -> (Result<()>, Option) {
+    pub fn flush(&mut self) -> Result<()> {
+        self.commands.send(SamplerCommand::Flush)?;
+        let response = self
+            .responses
+            .recv()
+            .context("Could not receive flush response from controller thread")?;
+        let SamplerResponse::Ok() = response else {
+            bail!("Got invalid response from sample controller thread");
+        };
+        Ok(())
+    }
+
+    pub fn abort(self) -> Result<(Option, F)> {
         drop(self.commands);
         let result = self.main_thread.join();
         match result {
             Err(payload) => std::panic::resume_unwind(payload),
-            Ok(Ok(traces)) => {
-                let (traces, errors): (Vec<_>, Vec<_>) = traces.into_iter().partition_result();
-                let trace: Trace = traces.into_iter().flatten().into();
-                match errors.into_iter().next() {
-                    Some(err) => (Err(err), Some(trace)),
-                    None => (Ok(()), Some(trace)),
-                }
-            }
-            Ok(Err(err)) => (Err(err), None),
+            Ok(Ok(val)) => Ok(val),
+            Ok(Err(err)) => Err(err),
         }
     }
 
-    pub fn inspect_trace(&mut self) -> Result {
-        self.commands.send(SamplerCommand::InspectTrace)?;
-        let response = self.responses.recv()?;
-        let SamplerResponse::IntermediateTrace(trace) = response else {
-            bail!("Got invalid response from sample controller thread");
-        };
-        Ok(trace)
-    }
-
-    pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult {
+    pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult {
         let start = Instant::now();
         let mut remaining = Some(timeout);
         while remaining.is_some() {
             match self.results.recv_timeout(timeout) {
                 Ok(Ok(_)) => remaining = timeout.checked_sub(start.elapsed()),
                 Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
-                Err(RecvTimeoutError::Disconnected) => {
-                    let (res, trace) = self.abort();
-                    if let Err(err) = res {
-                        return SamplerWaitResult::Err(err, trace);
-                    }
-                    return SamplerWaitResult::Trace(trace.expect("No chains available"));
-                }
+                Err(RecvTimeoutError::Disconnected) => match self.abort() {
+                    Ok((Some(err), trace)) => return SamplerWaitResult::Err(err, Some(trace)),
+                    Ok((None, trace)) => return SamplerWaitResult::Trace(trace),
+                    Err(err) => return SamplerWaitResult::Err(err, None),
+                },
                 Err(RecvTimeoutError::Timeout) => break,
             }
         }
@@ -915,22 +959,18 @@
 #[cfg(test)]
 pub mod test_logps {
-    use std::sync::Arc;
+    use std::collections::HashMap;
 
     use crate::{
+        Model,
         cpu_math::{CpuLogpFunc, CpuMath},
         math_base::LogpError,
-        Settings,
     };
     use anyhow::Result;
-    use arrow::{
-        array::{Array, ArrayBuilder, FixedSizeListBuilder, PrimitiveBuilder},
-        datatypes::Float64Type,
-    };
+    use nuts_storable::HasDims;
+    use rand::Rng;
     use thiserror::Error;
 
-    use super::{DrawStorage, Model};
-
     #[derive(Clone, Debug)]
     pub struct NormalLogp {
         pub dim: usize,
@@ -946,9 +986,21 @@
         }
     }
 
+    impl HasDims for &NormalLogp {
+        fn dim_sizes(&self) -> HashMap {
+            vec![
+                ("unconstrained_parameter".to_string(), self.dim as u64),
+                ("dim".to_string(), self.dim as u64),
+            ]
+            .into_iter()
+            .collect()
+        }
+    }
+
    impl CpuLogpFunc for &NormalLogp {
        type LogpError = NormalLogpError;
-       type TransformParams = ();
+       type FlowParameters = ();
+       type ExpandedVector = Vec;

        fn dim(&self) -> usize {
            self.dim
        }

            Ok(logp)
        }

+        fn expand_vector(
+            &mut
self, + _rng: &mut R, + array: &[f64], + ) -> std::result::Result + where + R: rand::Rng + ?Sized, + { + Ok(array.to_vec()) + } + fn inv_transform_normalize( &mut self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, _untransformed_position: &[f64], _untransofrmed_gradient: &[f64], _transformed_position: &mut [f64], @@ -981,7 +1044,7 @@ pub mod test_logps { fn init_from_untransformed_position( &mut self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, _untransformed_position: &[f64], _untransformed_gradient: &mut [f64], _transformed_position: &mut [f64], @@ -992,7 +1055,7 @@ pub mod test_logps { fn init_from_transformed_position( &mut self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, _untransformed_position: &mut [f64], _untransformed_gradient: &mut [f64], _transformed_position: &[f64], @@ -1007,7 +1070,7 @@ pub mod test_logps { _untransformed_positions: impl Iterator, _untransformed_gradients: impl Iterator, _untransformed_logp: impl Iterator, - _params: &'b mut Self::TransformParams, + _params: &'b mut Self::FlowParameters, ) -> std::result::Result<(), Self::LogpError> { unimplemented!() } @@ -1018,46 +1081,18 @@ pub mod test_logps { _untransformed_position: &[f64], _untransfogmed_gradient: &[f64], _chain: u64, - ) -> std::result::Result { + ) -> std::result::Result { unimplemented!() } fn transformation_id( &self, - _params: &Self::TransformParams, + _params: &Self::FlowParameters, ) -> std::result::Result { unimplemented!() } } - pub struct SimpleDrawStorage { - draws: FixedSizeListBuilder>, - } - - impl SimpleDrawStorage { - pub fn new(size: usize) -> Self { - let items = PrimitiveBuilder::new(); - let draws = FixedSizeListBuilder::new(items, size as _); - Self { draws } - } - } - - impl DrawStorage for SimpleDrawStorage { - fn append_value(&mut self, point: &[f64]) -> Result<()> { - self.draws.values().append_slice(point); - self.draws.append(true); - Ok(()) - } - - fn finalize(mut self) -> Result> { - Ok(ArrayBuilder::finish(&mut self.draws)) - } - - fn inspect(&self) -> Result> { - Ok(ArrayBuilder::finish_cloned(&self.draws)) - } - } - pub struct CpuModel { logp: F, } @@ -1068,24 +1103,14 @@ pub mod test_logps { } } - impl Model for CpuModel + impl Model for CpuModel where + F: Send + Sync + 'static, for<'a> &'a F: CpuLogpFunc, { type Math<'model> = CpuMath<&'model F>; - type DrawStorage<'model, S: Settings> = SimpleDrawStorage; - - fn new_trace<'model, S: Settings, R: rand::prelude::Rng + ?Sized>( - &'model self, - _rng: &mut R, - _chain_id: u64, - _settings: &'model S, - ) -> Result> { - Ok(SimpleDrawStorage::new((&self.logp).dim())) - } - - fn math(&self) -> Result> { + fn math(&self, _rng: &mut R) -> Result> { Ok(CpuMath::new(&self.logp)) } @@ -1102,20 +1127,24 @@ pub mod test_logps { #[cfg(test)] mod tests { - use std::time::{Duration, Instant}; + use std::{ + sync::Arc, + time::{Duration, Instant}, + }; use super::test_logps::NormalLogp; use crate::{ + Chain, DiagGradNutsSettings, Sampler, ZarrConfig, cpu_math::CpuMath, sample_sequentially, - sampler::{test_logps::CpuModel, Settings}, - Chain, DiagGradNutsSettings, Sampler, + sampler::{Settings, test_logps::CpuModel}, }; use anyhow::Result; use itertools::Itertools; use pretty_assertions::assert_eq; - use rand::{rngs::StdRng, SeedableRng}; + use rand::{SeedableRng, rngs::StdRng}; + use zarrs::storage::store::MemoryStore; #[test] fn sample_chain() -> Result<()> { @@ -1160,28 +1189,37 @@ mod tests { }; let model = CpuModel::new(logp.clone()); - let mut sampler = 
Sampler::new(model, settings, 4, None)?; + let store = MemoryStore::new(); + + let zarr_config = ZarrConfig::new(Arc::new(store)); + let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?; sampler.pause()?; sampler.pause()?; - let _trace = sampler.inspect_trace()?; + // TODO flush trace sampler.resume()?; - let (ok, trace) = sampler.abort(); - ok?; - assert!(trace.expect("No trace").chains.len() <= settings.num_chains); + let (ok, _) = sampler.abort()?; + if let Some(err) = ok { + Err(err)?; + } + let store = MemoryStore::new(); + let zarr_config = ZarrConfig::new(Arc::new(store)); let model = CpuModel::new(logp.clone()); - let mut sampler = Sampler::new(model, settings, 4, None)?; + let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?; sampler.pause()?; - sampler.abort().0?; + if let (Some(err), _) = sampler.abort()? { + Err(err)?; + } + let store = MemoryStore::new(); + let zarr_config = ZarrConfig::new(Arc::new(store)); let model = CpuModel::new(logp.clone()); let start = Instant::now(); - let sampler = Sampler::new(model, settings, 4, None)?; + let sampler = Sampler::new(model, settings, zarr_config, 4, None)?; let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) { - super::SamplerWaitResult::Trace(trace) => { + super::SamplerWaitResult::Trace(_) => { dbg!(start.elapsed()); - assert!(trace.chains.len() == settings.num_chains); panic!("finished"); } super::SamplerWaitResult::Timeout(sampler) => sampler, @@ -1195,12 +1233,8 @@ mod tests { } match sampler.wait_timeout(Duration::from_secs(1)) { - super::SamplerWaitResult::Trace(trace) => { + super::SamplerWaitResult::Trace(_) => { dbg!(start.elapsed()); - assert!(trace.chains.len() == settings.num_chains); - trace.chains.iter().for_each(|chain| { - assert!(chain.draws.len() as u64 == settings.num_tune + settings.num_draws); - }); } super::SamplerWaitResult::Timeout(_) => { panic!("timeout") diff --git a/src/sampler_stats.rs b/src/sampler_stats.rs index 3b3b457..e2bed96 100644 --- a/src/sampler_stats.rs +++ b/src/sampler_stats.rs @@ -1,33 +1,40 @@ -use arrow::array::StructArray; +use std::collections::HashMap; -use crate::{Math, Settings}; +use nuts_storable::{HasDims, Storable, Value}; -pub trait SamplerStats { - type Builder: StatTraceBuilder; - type StatOptions; - - fn new_builder( - &self, - options: Self::StatOptions, - settings: &impl Settings, - dim: usize, - ) -> Self::Builder; -} +use crate::Math; -pub trait StatTraceBuilder: Send { - fn append_value(&mut self, math: Option<&mut M>, value: &T); - fn finalize(self) -> Option; - fn inspect(&self) -> Option; +#[derive(Clone)] +pub struct StatsDims { + n_dim: u64, + coord: Option, } -impl StatTraceBuilder for () { - fn append_value(&mut self, _math: Option<&mut M>, _value: &T) {} +impl HasDims for StatsDims { + fn dim_sizes(&self) -> std::collections::HashMap { + std::collections::HashMap::from([("unconstrained_parameter".to_string(), self.n_dim)]) + } - fn finalize(self) -> Option { - None + fn coords(&self) -> HashMap { + if let Some(coord) = &self.coord { + return HashMap::from([("unconstrained_parameter".to_string(), coord.clone())]); + } + HashMap::new() } +} - fn inspect(&self) -> Option { - None +impl From<&M> for StatsDims { + fn from(math: &M) -> Self { + StatsDims { + n_dim: math.dim() as u64, + coord: math.vector_coord(), + } } } + +pub trait SamplerStats { + type Stats: Storable; + type StatsOptions: Copy + Send + Sync; + + fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats; +} diff --git 
a/src/state.rs b/src/state.rs
index 94f277d..c680924 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -104,11 +104,10 @@ impl> State {
 
 impl> Drop for State {
     fn drop(&mut self) {
         let rc = unsafe { std::mem::ManuallyDrop::take(&mut self.inner) };
-        if (Rc::strong_count(&rc) == 1) & (Rc::weak_count(&rc) == 0) {
-            if let Some(storage) = rc.reuser.upgrade() {
+        if (Rc::strong_count(&rc) == 1) & (Rc::weak_count(&rc) == 0)
+            && let Some(storage) = rc.reuser.upgrade() {
                 storage.free_states.borrow_mut().push(rc);
             }
-        }
     }
 }
diff --git a/src/stepsize/adam.rs b/src/stepsize/adam.rs
new file mode 100644
index 0000000..dc88e8c
--- /dev/null
+++ b/src/stepsize/adam.rs
@@ -0,0 +1,112 @@
+//! Adam optimizer for step size adaptation.
+//!
+//! This implements a single-parameter version of the Adam optimizer
+//! for adapting the step size in the NUTS algorithm. Unlike dual averaging,
+//! Adam maintains both first and second moment estimates of gradients,
+//! which can potentially lead to better adaptation in some scenarios.
+
+use std::f64;
+
+use serde::Serialize;
+
+/// Settings for Adam step size adaptation
+#[derive(Debug, Clone, Copy, Serialize)]
+pub struct AdamOptions {
+    /// First moment decay rate (default: 0.9)
+    pub beta1: f64,
+    /// Second moment decay rate (default: 0.999)
+    pub beta2: f64,
+    /// Small constant for numerical stability (default: 1e-8)
+    pub epsilon: f64,
+    /// Learning rate (default: 0.05)
+    pub learning_rate: f64,
+}
+
+impl Default for AdamOptions {
+    fn default() -> Self {
+        Self {
+            beta1: 0.9,
+            beta2: 0.999,
+            epsilon: 1e-8,
+            learning_rate: 0.05,
+        }
+    }
+}
+
+/// Adam optimizer for step size adaptation.
+///
+/// This implements the Adam optimizer for a single parameter (the step size).
+/// The adaptation takes the acceptance probability statistic and adjusts
+/// the step size to reach the target acceptance rate.
+#[derive(Clone)]
+pub struct Adam {
+    /// Current log step size
+    log_step: f64,
+    /// First moment estimate
+    m: f64,
+    /// Second moment estimate
+    v: f64,
+    /// Iteration counter
+    t: u64,
+    /// Adam settings
+    settings: AdamOptions,
+}
+
+impl Adam {
+    /// Create a new Adam optimizer with given settings and initial step size
+    pub fn new(settings: AdamOptions, initial_step: f64) -> Self {
+        Self {
+            log_step: initial_step.ln(),
+            m: 0.0,
+            v: 0.0,
+            t: 0,
+            settings,
+        }
+    }
+
+    /// Advance the optimizer by one step using the current acceptance statistic.
+    ///
+    /// This updates the step size to move towards the target acceptance rate.
+    /// The error signal is the difference between the current and target acceptance rates.
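+    ///
+    /// A minimal usage sketch (illustrative only, using the types from this
+    /// module):
+    ///
+    /// ```ignore
+    /// let mut adam = Adam::new(AdamOptions::default(), 0.1);
+    /// // Acceptance above target: the step size is allowed to grow.
+    /// adam.advance(0.95, 0.8);
+    /// assert!(adam.current_step_size() > 0.1);
+    /// ```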
+    pub fn advance(&mut self, accept_stat: f64, target: f64) {
+        // We want to minimize (target - accept_stat)²; its gradient with respect
+        // to accept_stat is -2 * (target - accept_stat). Dropping the constant
+        // factor, we use (accept_stat - target) as the error signal.
+        let gradient = accept_stat - target;
+
+        // Increment timestep
+        self.t += 1;
+
+        // Update biased first moment estimate
+        self.m = self.settings.beta1 * self.m + (1.0 - self.settings.beta1) * gradient;
+
+        // Update biased second moment estimate
+        self.v = self.settings.beta2 * self.v + (1.0 - self.settings.beta2) * gradient * gradient;
+
+        // Compute bias-corrected first moment estimate
+        let m_hat = self.m / (1.0 - self.settings.beta1.powi(self.t as i32));
+
+        // Compute bias-corrected second moment estimate
+        let v_hat = self.v / (1.0 - self.settings.beta2.powi(self.t as i32));
+
+        // Update log step size.
+        // Note: if the error is positive (accept_stat > target), the step size is
+        // too conservative and is increased; if it is negative (accept_stat < target),
+        // the step size is decreased.
+        self.log_step +=
+            self.settings.learning_rate * m_hat / (v_hat.sqrt() + self.settings.epsilon);
+    }
+
+    /// Get the current step size (Adam keeps no separate smoothed/adapted estimate)
+    pub fn current_step_size(&self) -> f64 {
+        self.log_step.exp()
+    }
+
+    /// Reset the optimizer with a new initial step size (the bias factor used by
+    /// dual averaging has no equivalent here and is ignored)
+    #[allow(dead_code)]
+    pub fn reset(&mut self, initial_step: f64, _bias_factor: f64) {
+        self.log_step = initial_step.ln();
+        self.m = 0.0;
+        self.v = 0.0;
+        self.t = 0;
+    }
+}
diff --git a/src/stepsize/adapt.rs b/src/stepsize/adapt.rs
new file mode 100644
index 0000000..7dc9134
--- /dev/null
+++ b/src/stepsize/adapt.rs
@@ -0,0 +1,304 @@
+use itertools::Either;
+use nuts_derive::Storable;
+use rand::Rng;
+use rand_distr::Uniform;
+use serde::Serialize;
+
+use super::adam::{Adam, AdamOptions};
+use super::dual_avg::{AcceptanceRateCollector, DualAverage, DualAverageOptions};
+use crate::{
+    Math, NutsError,
+    hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point},
+    nuts::{Collector, NutsOptions},
+    sampler_stats::SamplerStats,
+};
+use std::fmt::Debug;
+
+/// Method used for step size adaptation
+#[derive(Debug, Clone, Copy, Serialize, Default)]
+pub enum StepSizeAdaptMethod {
+    /// Use dual averaging for step size adaptation (default)
+    #[default]
+    DualAverage,
+    /// Use Adam optimizer for step size adaptation
+    Adam,
+    /// Use a fixed step size (no adaptation)
+    Fixed(f64),
+}
+
+/// Options for step size adaptation
+#[derive(Debug, Clone, Copy, Serialize)]
+pub struct StepSizeAdaptOptions {
+    pub method: StepSizeAdaptMethod,
+    /// Dual averaging adaptation options
+    pub dual_average: DualAverageOptions,
+    /// Adam optimizer adaptation options
+    pub adam: AdamOptions,
+}
+
+impl Default for StepSizeAdaptOptions {
+    fn default() -> Self {
+        Self {
+            method: StepSizeAdaptMethod::DualAverage,
+            dual_average: DualAverageOptions::default(),
+            adam: AdamOptions::default(),
+        }
+    }
+}
+
+/// Step size adaptation strategy
+pub struct Strategy {
+    /// The step size adaptation method being used
+    adaptation: Option<Either<DualAverage, Adam>>,
+    /// Settings for step size adaptation
+    options: StepSizeSettings,
+    /// Last mean tree accept rate
+    pub last_mean_tree_accept: f64,
+    /// Last symmetric mean tree accept rate
+    pub last_sym_mean_tree_accept: f64,
+    /// Last number of steps
+    pub last_n_steps: u64,
+}
+
+impl Strategy {
+    pub fn new(options: StepSizeSettings) -> Self {
+        let adaptation = match options.adapt_options.method {
+            StepSizeAdaptMethod::DualAverage => Some(Either::Left(DualAverage::new(
+                options.adapt_options.dual_average,
options.initial_step, + ))), + StepSizeAdaptMethod::Adam => Some(Either::Right(Adam::new( + options.adapt_options.adam, + options.initial_step, + ))), + StepSizeAdaptMethod::Fixed(_) => None, + }; + + Self { + adaptation, + options, + last_n_steps: 0, + last_sym_mean_tree_accept: 0.0, + last_mean_tree_accept: 0.0, + } + } + + pub fn init>( + &mut self, + math: &mut M, + options: &mut NutsOptions, + hamiltonian: &mut impl Hamiltonian, + position: &[f64], + rng: &mut R, + ) -> Result<(), NutsError> { + if let StepSizeAdaptMethod::Fixed(step_size) = self.options.adapt_options.method { + *hamiltonian.step_size_mut() = step_size; + return Ok(()); + }; + let mut state = hamiltonian.init_state(math, position)?; + hamiltonian.initialize_trajectory(math, &mut state, rng)?; + + let mut collector = AcceptanceRateCollector::new(); + + collector.register_init(math, &state, options); + + *hamiltonian.step_size_mut() = self.options.initial_step; + + let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector); + + let LeapfrogResult::Ok(_) = state_next else { + return Ok(()); + }; + + let accept_stat = collector.mean.current(); + let dir = if accept_stat > self.options.target_accept { + Direction::Forward + } else { + Direction::Backward + }; + + for _ in 0..100 { + let mut collector = AcceptanceRateCollector::new(); + collector.register_init(math, &state, options); + let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector); + let LeapfrogResult::Ok(_) = state_next else { + *hamiltonian.step_size_mut() = self.options.initial_step; + return Ok(()); + }; + let accept_stat = collector.mean.current(); + match dir { + Direction::Forward => { + if (accept_stat <= self.options.target_accept) | (hamiltonian.step_size() > 1e5) + { + match self.adaptation.as_mut().expect("Adaptation must be set") { + Either::Left(adapt) => { + *adapt = DualAverage::new( + self.options.adapt_options.dual_average, + hamiltonian.step_size(), + ); + } + Either::Right(adapt) => { + *adapt = Adam::new( + self.options.adapt_options.adam, + hamiltonian.step_size(), + ); + } + } + return Ok(()); + } + *hamiltonian.step_size_mut() *= 2.; + } + Direction::Backward => { + if (accept_stat >= self.options.target_accept) + | (hamiltonian.step_size() < 1e-10) + { + match self.adaptation.as_mut().expect("Adaptation must be set") { + Either::Left(adapt) => { + *adapt = DualAverage::new( + self.options.adapt_options.dual_average, + hamiltonian.step_size(), + ); + } + Either::Right(adapt) => { + *adapt = Adam::new( + self.options.adapt_options.adam, + hamiltonian.step_size(), + ); + } + } + return Ok(()); + } + *hamiltonian.step_size_mut() /= 2.; + } + } + } + // If we don't find something better, use the specified initial value + *hamiltonian.step_size_mut() = self.options.initial_step; + Ok(()) + } + + pub fn update(&mut self, collector: &AcceptanceRateCollector) { + let mean_sym = collector.mean_sym.current(); + let mean = collector.mean.current(); + let n_steps = collector.mean.count(); + self.last_mean_tree_accept = mean; + self.last_sym_mean_tree_accept = mean_sym; + self.last_n_steps = n_steps; + } + + pub fn update_estimator_early(&mut self) { + match self.adaptation.as_mut() { + None => {} + Some(Either::Left(adapt)) => { + adapt.advance(self.last_mean_tree_accept, self.options.target_accept); + } + Some(Either::Right(adapt)) => { + adapt.advance(self.last_mean_tree_accept, self.options.target_accept); + } + } + } + + pub fn update_estimator_late(&mut self) { + match self.adaptation.as_mut() { + 
None => {} + Some(Either::Left(adapt)) => { + adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept); + } + Some(Either::Right(adapt)) => { + adapt.advance(self.last_sym_mean_tree_accept, self.options.target_accept); + } + } + } + + pub fn update_stepsize( + &mut self, + rng: &mut R, + hamiltonian: &mut impl Hamiltonian, + use_best_guess: bool, + ) { + let step_size = match self.adaptation { + None => { + if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method { + val + } else { + panic!("Adaptation method must be Fixed if adaptation is None") + } + } + Some(Either::Left(ref adapt)) => { + if use_best_guess { + adapt.current_step_size_adapted() + } else { + adapt.current_step_size() + } + } + Some(Either::Right(ref adapt)) => adapt.current_step_size(), + }; + + if let Some(jitter) = self.options.jitter { + let jitter = + rng.sample(Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter")); + let jittered_step_size = step_size * jitter; + *hamiltonian.step_size_mut() = jittered_step_size; + } else { + *hamiltonian.step_size_mut() = step_size; + } + } + + pub fn new_collector(&self) -> AcceptanceRateCollector { + AcceptanceRateCollector::new() + } +} + +#[derive(Debug, Storable)] +pub struct Stats { + pub step_size_bar: f64, + pub mean_tree_accept: f64, + pub mean_tree_accept_sym: f64, + pub n_steps: u64, +} + +impl SamplerStats for Strategy { + type Stats = Stats; + type StatsOptions = (); + + fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { + Stats { + step_size_bar: match self.adaptation { + None => { + if let StepSizeAdaptMethod::Fixed(val) = self.options.adapt_options.method { + val + } else { + panic!("Adaptation method must be Fixed if adaptation is None") + } + } + Some(Either::Left(ref adapt)) => adapt.current_step_size_adapted(), + Some(Either::Right(ref adapt)) => adapt.current_step_size(), + }, + mean_tree_accept: self.last_mean_tree_accept, + mean_tree_accept_sym: self.last_sym_mean_tree_accept, + n_steps: self.last_n_steps, + } + } +} + +#[derive(Debug, Clone, Copy, Serialize)] +pub struct StepSizeSettings { + /// Target acceptance rate + pub target_accept: f64, + /// Initial step size + pub initial_step: f64, + /// Optional jitter to add to step size (randomization) + pub jitter: Option, + /// Adaptation options specific to the chosen method + pub adapt_options: StepSizeAdaptOptions, +} + +impl Default for StepSizeSettings { + fn default() -> Self { + Self { + target_accept: 0.8, + initial_step: 0.1, + jitter: Some(0.1), + adapt_options: StepSizeAdaptOptions::default(), + } + } +} diff --git a/src/stepsize.rs b/src/stepsize/dual_avg.rs similarity index 98% rename from src/stepsize.rs rename to src/stepsize/dual_avg.rs index 2556d8f..3f6d613 100644 --- a/src/stepsize.rs +++ b/src/stepsize/dual_avg.rs @@ -1,3 +1,5 @@ +use serde::Serialize; + use crate::{ hamiltonian::{DivergenceInfo, Point}, math_base::Math, @@ -6,7 +8,7 @@ use crate::{ }; /// Settings for step size adaptation -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Serialize)] pub struct DualAverageOptions { pub k: f64, pub t0: f64, diff --git a/src/stepsize/mod.rs b/src/stepsize/mod.rs new file mode 100644 index 0000000..97ca3be --- /dev/null +++ b/src/stepsize/mod.rs @@ -0,0 +1,8 @@ +mod adam; +mod adapt; +mod dual_avg; + +pub use adam::AdamOptions; +pub(crate) use adapt::Strategy; +pub use adapt::{StepSizeAdaptMethod, StepSizeAdaptOptions, StepSizeSettings}; +pub(crate) use dual_avg::AcceptanceRateCollector; diff --git 
a/src/stepsize_adapt.rs b/src/stepsize_adapt.rs deleted file mode 100644 index f8323e9..0000000 --- a/src/stepsize_adapt.rs +++ /dev/null @@ -1,239 +0,0 @@ -use arrow::{ - array::{ArrayBuilder, PrimitiveBuilder, StructArray}, - datatypes::{DataType, Field, Float64Type, UInt64Type}, -}; -use rand::Rng; - -use crate::{ - hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point}, - nuts::{Collector, NutsOptions}, - sampler_stats::{SamplerStats, StatTraceBuilder}, - stepsize::{AcceptanceRateCollector, DualAverage, DualAverageOptions}, - Math, NutsError, Settings, -}; - -pub struct Strategy { - step_size_adapt: DualAverage, - options: DualAverageSettings, - pub last_mean_tree_accept: f64, - pub last_sym_mean_tree_accept: f64, - pub last_n_steps: u64, -} - -impl Strategy { - pub fn new(options: DualAverageSettings) -> Self { - Self { - options, - step_size_adapt: DualAverage::new(options.params, options.initial_step), - last_n_steps: 0, - last_sym_mean_tree_accept: 0.0, - last_mean_tree_accept: 0.0, - } - } - - pub fn init>( - &mut self, - math: &mut M, - options: &mut NutsOptions, - hamiltonian: &mut impl Hamiltonian, - position: &[f64], - rng: &mut R, - ) -> Result<(), NutsError> { - let mut state = hamiltonian.init_state(math, position)?; - hamiltonian.initialize_trajectory(math, &mut state, rng)?; - - let mut collector = AcceptanceRateCollector::new(); - - collector.register_init(math, &state, options); - - *hamiltonian.step_size_mut() = self.options.initial_step; - - let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector); - - let LeapfrogResult::Ok(_) = state_next else { - return Ok(()); - }; - - let accept_stat = collector.mean.current(); - let dir = if accept_stat > self.options.target_accept { - Direction::Forward - } else { - Direction::Backward - }; - - for _ in 0..100 { - let mut collector = AcceptanceRateCollector::new(); - collector.register_init(math, &state, options); - let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector); - let LeapfrogResult::Ok(_) = state_next else { - *hamiltonian.step_size_mut() = self.options.initial_step; - return Ok(()); - }; - let accept_stat = collector.mean.current(); - match dir { - Direction::Forward => { - if (accept_stat <= self.options.target_accept) | (hamiltonian.step_size() > 1e5) - { - self.step_size_adapt = - DualAverage::new(self.options.params, hamiltonian.step_size()); - return Ok(()); - } - *hamiltonian.step_size_mut() *= 2.; - } - Direction::Backward => { - if (accept_stat >= self.options.target_accept) - | (hamiltonian.step_size() < 1e-10) - { - self.step_size_adapt = - DualAverage::new(self.options.params, hamiltonian.step_size()); - return Ok(()); - } - *hamiltonian.step_size_mut() /= 2.; - } - } - } - // If we don't find something better, use the specified initial value - *hamiltonian.step_size_mut() = self.options.initial_step; - Ok(()) - } - - pub fn update(&mut self, collector: &AcceptanceRateCollector) { - let mean_sym = collector.mean_sym.current(); - let mean = collector.mean.current(); - let n_steps = collector.mean.count(); - self.last_mean_tree_accept = mean; - self.last_sym_mean_tree_accept = mean_sym; - self.last_n_steps = n_steps; - } - - pub fn update_estimator_early(&mut self) { - self.step_size_adapt - .advance(self.last_mean_tree_accept, self.options.target_accept); - } - - pub fn update_estimator_late(&mut self) { - self.step_size_adapt - .advance(self.last_sym_mean_tree_accept, self.options.target_accept); - } - - pub fn update_stepsize( - &mut self, - 
potential: &mut impl Hamiltonian, - use_best_guess: bool, - ) { - if use_best_guess { - *potential.step_size_mut() = self.step_size_adapt.current_step_size_adapted(); - } else { - *potential.step_size_mut() = self.step_size_adapt.current_step_size(); - } - } - - pub fn new_collector(&self) -> AcceptanceRateCollector { - AcceptanceRateCollector::new() - } -} - -pub struct StatsBuilder { - step_size_bar: PrimitiveBuilder, - mean_tree_accept: PrimitiveBuilder, - mean_tree_accept_sym: PrimitiveBuilder, - n_steps: PrimitiveBuilder, -} - -impl StatTraceBuilder for StatsBuilder { - fn append_value(&mut self, _math: Option<&mut M>, value: &Strategy) { - self.step_size_bar - .append_value(value.step_size_adapt.current_step_size_adapted()); - self.mean_tree_accept - .append_value(value.last_mean_tree_accept); - self.mean_tree_accept_sym - .append_value(value.last_sym_mean_tree_accept); - self.n_steps.append_value(value.last_n_steps); - } - - fn finalize(self) -> Option { - let Self { - mut step_size_bar, - mut mean_tree_accept, - mut mean_tree_accept_sym, - mut n_steps, - } = self; - - let fields = vec![ - Field::new("step_size_bar", DataType::Float64, false), - Field::new("mean_tree_accept", DataType::Float64, false), - Field::new("mean_tree_accept_sym", DataType::Float64, false), - Field::new("n_steps", DataType::UInt64, false), - ]; - - let arrays = vec![ - ArrayBuilder::finish(&mut step_size_bar), - ArrayBuilder::finish(&mut mean_tree_accept), - ArrayBuilder::finish(&mut mean_tree_accept_sym), - ArrayBuilder::finish(&mut n_steps), - ]; - - Some(StructArray::new(fields.into(), arrays, None)) - } - - fn inspect(&self) -> Option { - let Self { - step_size_bar, - mean_tree_accept, - mean_tree_accept_sym, - n_steps, - } = self; - - let fields = vec![ - Field::new("step_size_bar", DataType::Float64, false), - Field::new("mean_tree_accept", DataType::Float64, false), - Field::new("mean_tree_accept_sym", DataType::Float64, false), - Field::new("n_steps", DataType::UInt64, false), - ]; - - let arrays = vec![ - ArrayBuilder::finish_cloned(step_size_bar), - ArrayBuilder::finish_cloned(mean_tree_accept), - ArrayBuilder::finish_cloned(mean_tree_accept_sym), - ArrayBuilder::finish_cloned(n_steps), - ]; - - Some(StructArray::new(fields.into(), arrays, None)) - } -} - -impl SamplerStats for Strategy { - type Builder = StatsBuilder; - type StatOptions = (); - - fn new_builder( - &self, - _stat_options: Self::StatOptions, - _settings: &impl Settings, - _dim: usize, - ) -> Self::Builder { - Self::Builder { - step_size_bar: PrimitiveBuilder::new(), - mean_tree_accept: PrimitiveBuilder::new(), - mean_tree_accept_sym: PrimitiveBuilder::new(), - n_steps: PrimitiveBuilder::new(), - } - } -} - -#[derive(Debug, Clone, Copy)] -pub struct DualAverageSettings { - pub target_accept: f64, - pub initial_step: f64, - pub params: DualAverageOptions, -} - -impl Default for DualAverageSettings { - fn default() -> Self { - Self { - target_accept: 0.8, - initial_step: 0.1, - params: DualAverageOptions::default(), - } - } -} diff --git a/src/storage/csv.rs b/src/storage/csv.rs new file mode 100644 index 0000000..3773223 --- /dev/null +++ b/src/storage/csv.rs @@ -0,0 +1,966 @@ +//! CSV storage backend for nuts-rs that outputs CmdStan-compatible CSV files +//! +//! This module provides a CSV storage backend that saves MCMC samples and +//! statistics in a format compatible with CmdStan, allowing existing Stan +//! analysis tools and libraries to read nuts-rs results. 
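+//!
+//! A minimal usage sketch (`model` stands for a hypothetical `Model`
+//! implementation; the argument list mirrors the calls in the tests at the
+//! bottom of this module):
+//!
+//! ```ignore
+//! use nuts_rs::{CsvConfig, DiagGradNutsSettings, Sampler};
+//!
+//! let settings = DiagGradNutsSettings::default();
+//! let config = CsvConfig::new("mcmc_output")
+//!     .with_precision(6)
+//!     .store_warmup(false);
+//! let sampler = Sampler::new(model, settings, config, 4, None)?;
+//! ```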
+
+use std::collections::HashMap;
+use std::fs::File;
+use std::io::{BufWriter, Write};
+use std::path::{Path, PathBuf};
+
+use anyhow::{Context, Result};
+use nuts_storable::{ItemType, Value};
+
+use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
+use crate::{Math, Progress, Settings};
+
+/// Configuration for CSV-based MCMC storage.
+///
+/// This storage backend creates Stan-compatible CSV files with one file per chain.
+/// Files are named `chain_{id}.csv` where `{id}` is the chain number starting from 0.
+///
+/// The CSV format matches CmdStan output:
+/// - Header row with column names
+/// - Warmup samples (if enabled) written before post-warmup samples
+/// - Standard Stan statistics (lp__, accept_stat__, stepsize__, treedepth__,
+///   n_leapfrog__, divergent__, energy__)
+/// - Parameter columns
pub struct CsvConfig {
+    /// Directory where CSV files will be written
+    output_dir: PathBuf,
+    /// Number of decimal places for floating point values
+    precision: usize,
+    /// Whether to store warmup samples (default: true)
+    store_warmup: bool,
+}
+
+impl CsvConfig {
+    /// Create a new CSV configuration.
+    ///
+    /// # Arguments
+    ///
+    /// * `output_dir` - Directory where CSV files will be written
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use nuts_rs::CsvConfig;
+    /// let config = CsvConfig::new("mcmc_output");
+    /// ```
+    pub fn new<P: AsRef<Path>>(output_dir: P) -> Self {
+        Self {
+            output_dir: output_dir.as_ref().to_path_buf(),
+            precision: 6,
+            store_warmup: true,
+        }
+    }
+
+    /// Set the precision (number of decimal places) for floating point values.
+    ///
+    /// Default is 6 decimal places.
+    pub fn with_precision(mut self, precision: usize) -> Self {
+        self.precision = precision;
+        self
+    }
+
+    /// Configure whether to store warmup samples.
+    ///
+    /// When true (default), warmup samples are written before the post-warmup samples.
+    /// When false, only post-warmup samples are stored.
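+    ///
+    /// ```ignore
+    /// // Keep only post-warmup draws in the CSV output:
+    /// let config = CsvConfig::new("mcmc_output").store_warmup(false);
+    /// ```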
+ pub fn store_warmup(mut self, store: bool) -> Self { + self.store_warmup = store; + self + } +} + +/// Main CSV storage managing multiple chains +pub struct CsvTraceStorage { + output_dir: PathBuf, + precision: usize, + store_warmup: bool, + parameter_names: Vec, + column_mapping: Vec<(String, usize)>, // (data_name, index_in_data) +} + +/// Per-chain CSV storage +pub struct CsvChainStorage { + writer: BufWriter, + precision: usize, + store_warmup: bool, + parameter_names: Vec, + column_mapping: Vec<(String, usize)>, // (data_name, index_in_data) + is_first_sample: bool, + headers_written: bool, +} + +impl CsvChainStorage { + /// Create a new CSV chain storage + fn new( + output_dir: &Path, + chain_id: u64, + precision: usize, + store_warmup: bool, + parameter_names: Vec, + column_mapping: Vec<(String, usize)>, + ) -> Result { + std::fs::create_dir_all(output_dir) + .with_context(|| format!("Failed to create output directory: {:?}", output_dir))?; + + let file_path = output_dir.join(format!("chain_{}.csv", chain_id)); + let file = File::create(&file_path) + .with_context(|| format!("Failed to create CSV file: {:?}", file_path))?; + let writer = BufWriter::new(file); + + Ok(Self { + writer, + precision, + store_warmup, + parameter_names, + column_mapping, + is_first_sample: true, + headers_written: false, + }) + } + + /// Write the CSV header row + fn write_header(&mut self) -> Result<()> { + if self.headers_written { + return Ok(()); + } + + // Standard CmdStan header format - only the core columns + let mut headers = vec![ + "lp__".to_string(), + "accept_stat__".to_string(), + "stepsize__".to_string(), + "treedepth__".to_string(), + "n_leapfrog__".to_string(), + "divergent__".to_string(), + "energy__".to_string(), + ]; + + // Add parameter columns from the expanded parameter vector + for param_name in &self.parameter_names { + headers.push(param_name.clone()); + } + + // Write header row + writeln!(self.writer, "{}", headers.join(","))?; + self.headers_written = true; + Ok(()) + } + + /// Format a value for CSV output + fn format_value(&self, value: &Value) -> String { + match value { + Value::ScalarF64(v) => { + if v.is_nan() { + "NA".to_string() + } else if v.is_infinite() { + if *v > 0.0 { "Inf" } else { "-Inf" }.to_string() + } else { + format!("{:.prec$}", v, prec = self.precision) + } + } + Value::ScalarF32(v) => { + if v.is_nan() { + "NA".to_string() + } else if v.is_infinite() { + if *v > 0.0 { "Inf" } else { "-Inf" }.to_string() + } else { + format!("{:.prec$}", v, prec = self.precision) + } + } + Value::ScalarU64(v) => v.to_string(), + Value::ScalarI64(v) => v.to_string(), + Value::ScalarBool(v) => if *v { "1" } else { "0" }.to_string(), + Value::F64(vec) => { + // For vector values, we'll just use the first element for now + // A more sophisticated implementation would handle multi-dimensional parameters + if vec.is_empty() { + "NA".to_string() + } else { + self.format_value(&Value::ScalarF64(vec[0])) + } + } + Value::F32(vec) => { + if vec.is_empty() { + "NA".to_string() + } else { + self.format_value(&Value::ScalarF32(vec[0])) + } + } + Value::U64(vec) => { + if vec.is_empty() { + "NA".to_string() + } else { + vec[0].to_string() + } + } + Value::I64(vec) => { + if vec.is_empty() { + "NA".to_string() + } else { + vec[0].to_string() + } + } + Value::Bool(vec) => { + if vec.is_empty() { + "NA".to_string() + } else { + if vec[0] { "1" } else { "0" }.to_string() + } + } + Value::ScalarString(v) => v.clone(), + Value::Strings(vec) => { + if vec.is_empty() { + "NA".to_string() + } 
else { + vec[0].clone() + } + } + } + } + + /// Write a single sample row to the CSV file + fn write_sample_row( + &mut self, + stats: &Vec<(&str, Option)>, + draws: &Vec<(&str, Option)>, + _info: &Progress, + ) -> Result<()> { + let mut row_values = Vec::new(); + + // Create lookup maps for quick access + let stats_map: HashMap<&str, &Option> = stats.iter().map(|(k, v)| (*k, v)).collect(); + let draws_map: HashMap<&str, &Option> = draws.iter().map(|(k, v)| (*k, v)).collect(); + + // Helper function to get stat value + let get_stat_value = |name: &str| -> String { + stats_map + .get(name) + .and_then(|opt| opt.as_ref()) + .map(|v| self.format_value(v)) + .unwrap_or_else(|| "NA".to_string()) + }; + + row_values.push(get_stat_value("logp")); + row_values.push(get_stat_value("mean_tree_accept")); + row_values.push(get_stat_value("step_size")); + row_values.push(get_stat_value("depth")); + row_values.push(get_stat_value("n_steps")); + let divergent_val = stats_map + .get("diverging") + .and_then(|opt| opt.as_ref()) + .map(|v| match v { + Value::ScalarBool(true) => "1".to_string(), + Value::ScalarBool(false) => "0".to_string(), + _ => "0".to_string(), + }) + .unwrap_or_else(|| "0".to_string()); + row_values.push(divergent_val); + + row_values.push(get_stat_value("energy")); + + // Add parameter values using the column mapping + for (_param_name, (data_name, index)) in + self.parameter_names.iter().zip(&self.column_mapping) + { + if let Some(Some(data_value)) = draws_map.get(data_name.as_str()) { + let formatted_value = match data_value { + Value::F64(vec) => { + if *index < vec.len() { + self.format_value(&Value::ScalarF64(vec[*index])) + } else { + "NA".to_string() + } + } + Value::F32(vec) => { + if *index < vec.len() { + self.format_value(&Value::ScalarF32(vec[*index])) + } else { + "NA".to_string() + } + } + Value::I64(vec) => { + if *index < vec.len() { + self.format_value(&Value::ScalarI64(vec[*index])) + } else { + "NA".to_string() + } + } + Value::U64(vec) => { + if *index < vec.len() { + self.format_value(&Value::ScalarU64(vec[*index])) + } else { + "NA".to_string() + } + } + // Handle scalar values (index should be 0) + scalar_val if *index == 0 => self.format_value(scalar_val), + _ => "NA".to_string(), + }; + row_values.push(formatted_value); + } else { + row_values.push("NA".to_string()); + } + } + + // Write the row + writeln!(self.writer, "{}", row_values.join(","))?; + Ok(()) + } +} + +impl ChainStorage for CsvChainStorage { + type Finalized = (); + + fn record_sample( + &mut self, + _settings: &impl Settings, + stats: Vec<(&str, Option)>, + draws: Vec<(&str, Option)>, + info: &Progress, + ) -> Result<()> { + // Skip warmup samples if not storing them + if info.tuning && !self.store_warmup { + return Ok(()); + } + + // Write header on first sample + if self.is_first_sample { + self.write_header()?; + self.is_first_sample = false; + } + + self.write_sample_row(&stats, &draws, info)?; + Ok(()) + } + + fn finalize(mut self) -> Result { + self.writer.flush().context("Failed to flush CSV file")?; + Ok(()) + } + + fn flush(&self) -> Result<()> { + // BufWriter doesn't provide a way to flush without mutable reference + // In practice, the buffer will be flushed when the file is closed + Ok(()) + } +} + +impl StorageConfig for CsvConfig { + type Storage = CsvTraceStorage; + + fn new_trace(self, settings: &impl Settings, math: &M) -> Result { + // Generate parameter names and column mapping using coordinates + let (parameter_names, column_mapping) = + 
generate_parameter_names_and_mapping(settings, math)?; + + Ok(CsvTraceStorage { + output_dir: self.output_dir, + precision: self.precision, + store_warmup: self.store_warmup, + parameter_names, + column_mapping, + }) + } +} + +/// Generate parameter column names and mapping using coordinates or Stan-compliant indexing +fn generate_parameter_names_and_mapping( + settings: &impl Settings, + math: &M, +) -> Result<(Vec, Vec<(String, usize)>)> { + let data_dims = settings.data_dims_all(math); + let coords = math.coords(); + let mut parameter_names = Vec::new(); + let mut column_mapping = Vec::new(); + + for (var_name, var_dims) in data_dims { + let data_type = settings.data_type(math, &var_name); + + // Only process vector types that could contain parameter values + if matches!( + data_type, + ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64 + ) { + let (column_names, indices) = generate_column_names_and_indices_for_variable( + &var_name, &var_dims, &coords, math, + )?; + + for (name, index) in column_names.into_iter().zip(indices) { + parameter_names.push(name); + column_mapping.push((var_name.clone(), index)); + } + } + } + + // If no parameter names were generated, fall back to simple numbering + if parameter_names.is_empty() { + let dim_sizes = math.dim_sizes(); + let param_count = dim_sizes.get("expanded_parameter").unwrap_or(&0); + for i in 0..*param_count { + parameter_names.push(format!("param_{}", i + 1)); + // Try to find a data field that contains the parameters + let data_names = settings.data_names(math); + let mut found_field = false; + for data_name in &data_names { + let data_type = settings.data_type(math, data_name); + if matches!( + data_type, + ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64 + ) { + column_mapping.push((data_name.clone(), i as usize)); + found_field = true; + break; + } + } + if !found_field { + column_mapping.push(("unknown".to_string(), i as usize)); + } + } + } + + Ok((parameter_names, column_mapping)) +} + +/// Generate column names and indices for a single variable using its dimensions and coordinates +fn generate_column_names_and_indices_for_variable( + var_name: &str, + var_dims: &[String], + coords: &HashMap, + math: &M, +) -> Result<(Vec, Vec)> { + let dim_sizes = math.dim_sizes(); + + if var_dims.is_empty() { + // Scalar variable + return Ok((vec![var_name.to_string()], vec![0])); + } + + // Check if we have meaningful coordinate names for all dimensions + let has_meaningful_coords = var_dims.iter().all(|dim_name| { + coords.get(dim_name).is_some_and( + |coord_value| matches!(coord_value, Value::Strings(labels) if !labels.is_empty()), + ) + }); + + // Get coordinate labels for each dimension + let mut dim_coords: Vec> = Vec::new(); + let mut dim_sizes_vec: Vec = Vec::new(); + + for dim_name in var_dims { + let size = *dim_sizes.get(dim_name).unwrap_or(&1) as usize; + dim_sizes_vec.push(size); + + if has_meaningful_coords { + // Use coordinate names if available and meaningful + if let Some(coord_value) = coords.get(dim_name) { + match coord_value { + Value::Strings(labels) => { + dim_coords.push(labels.clone()); + } + _ => { + // Fallback to 1-based indexing (Stan format) + dim_coords.push((1..=size).map(|i| i.to_string()).collect()); + } + } + } else { + // Fallback to 1-based indexing (Stan format) + dim_coords.push((1..=size).map(|i| i.to_string()).collect()); + } + } else { + // Use Stan-compliant 1-based indexing + dim_coords.push((1..=size).map(|i| i.to_string()).collect()); + } + } + + // Generate Cartesian 
product using column-major order (Stan format)
+    let (coord_names, indices) =
+        cartesian_product_with_indices_column_major(&dim_coords, &dim_sizes_vec);
+
+    // Prepend variable name to each coordinate combination
+    let column_names: Vec<String> = coord_names
+        .into_iter()
+        .map(|coord| format!("{}.{}", var_name, coord))
+        .collect();
+
+    Ok((column_names, indices))
+}
+
+/// Compute the Cartesian product of coordinate labels in the order used for
+/// Stan-style CSV headers.
+///
+/// Stan's documentation calls this ordering "column-major", but the traversal
+/// implemented here (and checked by the tests below) is row-major: the first
+/// index changes slowest and the last index changes fastest.
+/// For example, a 2x3 array produces: [1,1], [1,2], [1,3], [2,1], [2,2], [2,3]
+fn cartesian_product_with_indices_column_major(
+    coord_sets: &[Vec<String>],
+    dim_sizes: &[usize],
+) -> (Vec<String>, Vec<usize>) {
+    if coord_sets.is_empty() {
+        return (vec![], vec![]);
+    }
+
+    if coord_sets.len() == 1 {
+        let indices: Vec<usize> = (0..coord_sets[0].len()).collect();
+        return (coord_sets[0].clone(), indices);
+    }
+
+    let mut names = vec![];
+    let mut indices = vec![];
+
+    // Traverse the index tuples in row-major order (first index slowest)
+    cartesian_product_recursive_with_indices(
+        coord_sets,
+        dim_sizes,
+        0,
+        &mut String::new(),
+        &mut vec![],
+        &mut names,
+        &mut indices,
+    );
+
+    (names, indices)
+}
+
+fn cartesian_product_recursive_with_indices(
+    coord_sets: &[Vec<String>],
+    dim_sizes: &[usize],
+    dim_idx: usize,
+    current_name: &mut String,
+    current_indices: &mut Vec<usize>,
+    result_names: &mut Vec<String>,
+    result_indices: &mut Vec<usize>,
+) {
+    if dim_idx == coord_sets.len() {
+        result_names.push(current_name.clone());
+        // Compute the linear (row-major) index from the multi-dimensional indices
+        let mut linear_index = 0;
+        for (i, &idx) in current_indices.iter().enumerate() {
+            let mut stride = 1;
+            for &size in &dim_sizes[i + 1..]
{ + stride *= size; + } + linear_index += idx * stride; + } + result_indices.push(linear_index); + return; + } + + let is_first_dim = dim_idx == 0; + + for (coord_idx, coord) in coord_sets[dim_idx].iter().enumerate() { + let mut new_name = current_name.clone(); + if !is_first_dim { + new_name.push('.'); + } + new_name.push_str(coord); + + current_indices.push(coord_idx); + cartesian_product_recursive_with_indices( + coord_sets, + dim_sizes, + dim_idx + 1, + &mut new_name, + current_indices, + result_names, + result_indices, + ); + current_indices.pop(); + } +} + +impl TraceStorage for CsvTraceStorage { + type ChainStorage = CsvChainStorage; + type Finalized = (); + + fn initialize_trace_for_chain(&self, chain_id: u64) -> Result { + CsvChainStorage::new( + &self.output_dir, + chain_id, + self.precision, + self.store_warmup, + self.parameter_names.clone(), + self.column_mapping.clone(), + ) + } + + fn finalize( + self, + traces: Vec::Finalized>>, + ) -> Result<(Option, Self::Finalized)> { + // Check for any errors in the chain finalizations + for trace_result in traces { + if let Err(err) = trace_result { + return Ok((Some(err), ())); + } + } + Ok((None, ())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, Sampler, + }; + use anyhow::Result; + use nuts_derive::Storable; + use nuts_storable::{HasDims, Value}; + use rand::Rng; + use std::collections::HashMap; + use std::fs; + use std::path::Path; + use thiserror::Error; + + #[allow(dead_code)] + #[derive(Debug, Error)] + enum TestLogpError { + #[error("Test error")] + Test, + } + + impl LogpError for TestLogpError { + fn is_recoverable(&self) -> bool { + false + } + } + + /// Test model with multi-dimensional coordinates + #[derive(Clone)] + struct MultiDimTestLogp { + dim_a: usize, + dim_b: usize, + } + + impl HasDims for MultiDimTestLogp { + fn dim_sizes(&self) -> HashMap { + HashMap::from([ + ("a".to_string(), self.dim_a as u64), + ("b".to_string(), self.dim_b as u64), + ]) + } + + fn coords(&self) -> HashMap { + HashMap::from([ + ( + "a".to_string(), + Value::Strings(vec!["x".to_string(), "y".to_string()]), + ), + ( + "b".to_string(), + Value::Strings(vec!["alpha".to_string(), "beta".to_string()]), + ), + ]) + } + } + + #[derive(Storable)] + struct MultiDimExpandedDraw { + #[storable(dims("a", "b"))] + param_matrix: Vec, + scalar_value: f64, + } + + impl CpuLogpFunc for MultiDimTestLogp { + type LogpError = TestLogpError; + type FlowParameters = (); + type ExpandedVector = MultiDimExpandedDraw; + + fn dim(&self) -> usize { + self.dim_a * self.dim_b + } + + fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result { + let mut logp = 0.0; + for (i, &xi) in x.iter().enumerate() { + logp -= 0.5 * xi * xi; + grad[i] = -xi; + } + Ok(logp) + } + + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> Result { + Ok(MultiDimExpandedDraw { + param_matrix: array.to_vec(), + scalar_value: array.iter().sum(), + }) + } + + fn vector_coord(&self) -> Option { + Some(Value::Strings( + (0..self.dim()).map(|i| format!("theta{}", i + 1)).collect(), + )) + } + } + + struct MultiDimTestModel { + math: CpuMath, + } + + impl Model for MultiDimTestModel { + type Math<'model> + = CpuMath + where + Self: 'model; + + fn math(&self, _rng: &mut R) -> Result> { + Ok(self.math.clone()) + } + + fn init_position(&self, rng: &mut R, position: &mut [f64]) -> Result<()> { + for p in position.iter_mut() { + *p = rng.random_range(-1.0..1.0); + } + Ok(()) + } + } + + 
/// Test model without coordinates (fallback behavior) + #[derive(Clone)] + struct SimpleTestLogp { + dim: usize, + } + + impl HasDims for SimpleTestLogp { + fn dim_sizes(&self) -> HashMap { + HashMap::from([("simple_param".to_string(), self.dim as u64)]) + } + // No coords() method - should use fallback + } + + #[derive(Storable)] + struct SimpleExpandedDraw { + #[storable(dims("simple_param"))] + values: Vec, + } + + impl CpuLogpFunc for SimpleTestLogp { + type LogpError = TestLogpError; + type FlowParameters = (); + type ExpandedVector = SimpleExpandedDraw; + + fn dim(&self) -> usize { + self.dim + } + + fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result { + let mut logp = 0.0; + for (i, &xi) in x.iter().enumerate() { + logp -= 0.5 * xi * xi; + grad[i] = -xi; + } + Ok(logp) + } + + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> Result { + Ok(SimpleExpandedDraw { + values: array.to_vec(), + }) + } + + fn vector_coord(&self) -> Option { + Some(Value::Strings(vec![ + "param1".to_string(), + "param2".to_string(), + "param3".to_string(), + ])) + } + } + + struct SimpleTestModel { + math: CpuMath, + } + + impl Model for SimpleTestModel { + type Math<'model> + = CpuMath + where + Self: 'model; + + fn math(&self, _rng: &mut R) -> Result> { + Ok(self.math.clone()) + } + + fn init_position(&self, rng: &mut R, position: &mut [f64]) -> Result<()> { + for p in position.iter_mut() { + *p = rng.random_range(-1.0..1.0); + } + Ok(()) + } + } + + fn read_csv_header(path: &Path) -> Result { + let content = fs::read_to_string(path)?; + content + .lines() + .next() + .map(|s| s.to_string()) + .ok_or_else(|| anyhow::anyhow!("Empty CSV file")) + } + + #[test] + fn test_multidim_coordinate_naming() -> Result<()> { + let temp_dir = tempfile::tempdir()?; + let output_path = temp_dir.path().join("multidim_test"); + + // Create model with 2x2 parameter matrix + let model = MultiDimTestModel { + math: CpuMath::new(MultiDimTestLogp { dim_a: 2, dim_b: 2 }), + }; + + let mut settings = DiagGradNutsSettings::default(); + settings.num_chains = 1; + settings.num_tune = 10; + settings.num_draws = 20; + settings.seed = 42; + + let csv_config = CsvConfig::new(&output_path) + .with_precision(6) + .store_warmup(false); + + let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?); + + // Wait for sampling to complete + while let Some(sampler_) = sampler.take() { + match sampler_.wait_timeout(std::time::Duration::from_millis(100)) { + crate::SamplerWaitResult::Trace(_) => break, + crate::SamplerWaitResult::Timeout(s) => sampler = Some(s), + crate::SamplerWaitResult::Err(err, _) => return Err(err), + } + } + + // Check that CSV file was created + let csv_file = output_path.join("chain_0.csv"); + assert!(csv_file.exists()); + + // Check header contains expected coordinate names + let header = read_csv_header(&csv_file)?; + + // Should contain Cartesian product: x.alpha, x.beta, y.alpha, y.beta + assert!(header.contains("param_matrix.x.alpha")); + assert!(header.contains("param_matrix.x.beta")); + assert!(header.contains("param_matrix.y.alpha")); + assert!(header.contains("param_matrix.y.beta")); + assert!(header.contains("scalar_value")); + + // Verify column order (Cartesian product should be in correct order) + let columns: Vec<&str> = header.split(',').collect(); + let param_columns: Vec<&str> = columns + .iter() + .filter(|col| col.starts_with("param_matrix.")) + .cloned() + .collect(); + + assert_eq!( + param_columns, + vec![ + "param_matrix.x.alpha", + "param_matrix.x.beta", 
+ "param_matrix.y.alpha", + "param_matrix.y.beta" + ] + ); + + Ok(()) + } + + #[test] + fn test_fallback_coordinate_naming() -> Result<()> { + let temp_dir = tempfile::tempdir()?; + let output_path = temp_dir.path().join("simple_test"); + + // Create model with 3 parameters but no coordinate specification + let model = SimpleTestModel { + math: CpuMath::new(SimpleTestLogp { dim: 3 }), + }; + + let mut settings = DiagGradNutsSettings::default(); + settings.num_chains = 1; + settings.num_tune = 5; + settings.num_draws = 10; + settings.seed = 123; + + let csv_config = CsvConfig::new(&output_path) + .with_precision(6) + .store_warmup(false); + + let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?); + + // Wait for sampling to complete + while let Some(sampler_) = sampler.take() { + match sampler_.wait_timeout(std::time::Duration::from_millis(100)) { + crate::SamplerWaitResult::Trace(_) => break, + crate::SamplerWaitResult::Timeout(s) => sampler = Some(s), + crate::SamplerWaitResult::Err(err, _) => return Err(err), + } + } + + // Check that CSV file was created + let csv_file = output_path.join("chain_0.csv"); + assert!(csv_file.exists()); + + // Check header uses fallback numeric naming + let header = read_csv_header(&csv_file)?; + + // Should fall back to 1-based indices since no coordinates provided + assert!(header.contains("values.1")); + assert!(header.contains("values.2")); + assert!(header.contains("values.3")); + + Ok(()) + } + + #[test] + fn test_cartesian_product_generation() { + let coord_sets = vec![ + vec!["x".to_string(), "y".to_string()], + vec!["alpha".to_string(), "beta".to_string()], + ]; + let dim_sizes = vec![2, 2]; + + let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes); + + assert_eq!(names, vec!["x.alpha", "x.beta", "y.alpha", "y.beta"]); + + assert_eq!(indices, vec![0, 1, 2, 3]); + } + + #[test] + fn test_single_dimension_coordinates() { + let coord_sets = vec![vec!["param1".to_string(), "param2".to_string()]]; + let dim_sizes = vec![2]; + + let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes); + + assert_eq!(names, vec!["param1", "param2"]); + assert_eq!(indices, vec![0, 1]); + } + + #[test] + fn test_three_dimension_cartesian_product() { + let coord_sets = vec![ + vec!["a".to_string(), "b".to_string()], + vec!["1".to_string()], + vec!["i".to_string(), "j".to_string()], + ]; + let dim_sizes = vec![2, 1, 2]; + + let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes); + + assert_eq!(names, vec!["a.1.i", "a.1.j", "b.1.i", "b.1.j"]); + + assert_eq!(indices, vec![0, 1, 2, 3]); + } +} diff --git a/src/storage/hashmap.rs b/src/storage/hashmap.rs new file mode 100644 index 0000000..a5c6e2b --- /dev/null +++ b/src/storage/hashmap.rs @@ -0,0 +1,317 @@ +use anyhow::Result; +use nuts_storable::{ItemType, Value}; +use std::collections::HashMap; + +use crate::storage::{ChainStorage, StorageConfig, TraceStorage}; +use crate::{Progress, Settings}; + +/// Container for different types of sample values in HashMaps +#[derive(Clone, Debug)] +pub enum HashMapValue { + F64(Vec), + F32(Vec), + Bool(Vec), + I64(Vec), + U64(Vec), + String(Vec), +} + +impl HashMapValue { + /// Create a new empty HashMapValue of the specified type + fn new(item_type: ItemType) -> Self { + match item_type { + ItemType::F64 => HashMapValue::F64(Vec::new()), + ItemType::F32 => HashMapValue::F32(Vec::new()), + ItemType::Bool => HashMapValue::Bool(Vec::new()), + ItemType::I64 => 
HashMapValue::I64(Vec::new()),
+            ItemType::U64 => HashMapValue::U64(Vec::new()),
+            ItemType::String => HashMapValue::String(Vec::new()),
+        }
+    }
+
+    /// Push a value to the internal vector
+    fn push(&mut self, value: Value) {
+        match (self, value) {
+            // Scalar values - store as single element vectors for array types
+            (HashMapValue::F64(vec), Value::ScalarF64(v)) => vec.push(v),
+            (HashMapValue::F32(vec), Value::ScalarF32(v)) => vec.push(v),
+            (HashMapValue::U64(vec), Value::ScalarU64(v)) => vec.push(v),
+            (HashMapValue::Bool(vec), Value::ScalarBool(v)) => vec.push(v),
+            (HashMapValue::I64(vec), Value::ScalarI64(v)) => vec.push(v),
+            (HashMapValue::String(vec), Value::ScalarString(v)) => vec.push(v),
+
+            (HashMapValue::F64(vec), Value::F64(v)) => vec.extend(v),
+            (HashMapValue::F32(vec), Value::F32(v)) => vec.extend(v),
+            (HashMapValue::U64(vec), Value::U64(v)) => vec.extend(v),
+            (HashMapValue::Bool(vec), Value::Bool(v)) => vec.extend(v),
+            (HashMapValue::I64(vec), Value::I64(v)) => vec.extend(v),
+            (HashMapValue::String(vec), Value::Strings(v)) => vec.extend(v),
+
+            _ => panic!("Mismatched item type"),
+        }
+    }
+}
+
+/// Main storage for HashMap MCMC traces
+pub struct HashMapTraceStorage {
+    draw_types: Vec<(String, ItemType)>,
+    param_types: Vec<(String, ItemType)>,
+}
+
+/// Per-chain storage for HashMap MCMC traces
+pub struct HashMapChainStorage {
+    warmup_stats: HashMap<String, HashMapValue>,
+    sample_stats: HashMap<String, HashMapValue>,
+    warmup_draws: HashMap<String, HashMapValue>,
+    sample_draws: HashMap<String, HashMapValue>,
+    last_sample_was_warmup: bool,
+}
+
+/// Final result containing the collected samples
+#[derive(Debug, Clone)]
+pub struct HashMapResult {
+    /// HashMap containing sampler stats including warmup samples
+    pub stats: HashMap<String, HashMapValue>,
+    /// HashMap containing draws including warmup samples
+    pub draws: HashMap<String, HashMapValue>,
+}
+
+impl HashMapChainStorage {
+    /// Create a new chain storage with HashMaps for parameters and samples
+    fn new(param_types: &Vec<(String, ItemType)>, draw_types: &Vec<(String, ItemType)>) -> Self {
+        let warmup_stats = param_types
+            .iter()
+            .cloned()
+            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
+            .collect();
+
+        let sample_stats = param_types
+            .iter()
+            .cloned()
+            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
+            .collect();
+
+        let warmup_draws = draw_types
+            .iter()
+            .cloned()
+            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
+            .collect();
+
+        let sample_draws = draw_types
+            .iter()
+            .cloned()
+            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
+            .collect();
+
+        Self {
+            warmup_stats,
+            sample_stats,
+            warmup_draws,
+            sample_draws,
+            last_sample_was_warmup: true,
+        }
+    }
+
+    /// Store a parameter value
+    fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
+        if ["draw", "chain"].contains(&name) {
+            return Ok(());
+        }
+
+        let target_map = if is_warmup {
+            &mut self.warmup_stats
+        } else {
+            &mut self.sample_stats
+        };
+
+        if let Some(hash_value) = target_map.get_mut(name) {
+            hash_value.push(value);
+        } else {
+            panic!("Unknown param name: {}", name);
+        }
+        Ok(())
+    }
+
+    /// Store a draw value
+    fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
+        if ["draw", "chain"].contains(&name) {
+            return Ok(());
+        }
+
+        let target_map = if is_warmup {
+            &mut self.warmup_draws
+        } else {
+            &mut self.sample_draws
+        };
+
+        if let Some(hash_value) = target_map.get_mut(name) {
+            hash_value.push(value);
+        } else {
+            panic!("Unknown posterior variable name: {}", name);
+        }
+        Ok(())
+    }
+}
+
+impl ChainStorage for HashMapChainStorage {
+    type Finalized = HashMapResult;
+
+    fn record_sample(
+        &mut self,
+        _settings: &impl Settings,
+        stats: Vec<(&str,
Option)>, + draws: Vec<(&str, Option)>, + info: &Progress, + ) -> Result<()> { + let is_first_draw = self.last_sample_was_warmup && !info.tuning; + if is_first_draw { + self.last_sample_was_warmup = false; + } + + for (name, value) in stats { + if let Some(value) = value { + self.push_param(name, value, info.tuning)?; + } + } + for (name, value) in draws { + if let Some(value) = value { + self.push_draw(name, value, info.tuning)?; + } else { + panic!("Missing draw value for {}", name); + } + } + Ok(()) + } + + /// Finalize storage and return the collected samples + fn finalize(self) -> Result { + // Combine warmup and sample data + let mut combined_stats = HashMap::new(); + let mut combined_draws = HashMap::new(); + + // Combine stats + for (key, warmup_values) in self.warmup_stats { + let sample_values = &self.sample_stats[&key]; + let mut combined = warmup_values.clone(); + + match (&mut combined, sample_values) { + (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => { + combined_vec.extend(sample_vec.iter().cloned()); + } + (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => { + combined_vec.extend(sample_vec.iter().cloned()); + } + (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => { + combined_vec.extend(sample_vec.iter().cloned()); + } + (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => { + combined_vec.extend(sample_vec.iter().cloned()); + } + (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => { + combined_vec.extend(sample_vec.iter().cloned()); + } + _ => panic!("Type mismatch when combining stats for {}", key), + } + + combined_stats.insert(key, combined); + } + + // Combine draws + for (key, warmup_values) in self.warmup_draws { + let sample_values = &self.sample_draws[&key]; + let mut combined = warmup_values.clone(); + + match (&mut combined, sample_values) { + (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => { + combined_vec.extend(sample_vec.iter().cloned()); + } + (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => { + combined_vec.extend(sample_vec.iter().cloned()); + } + (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => { + combined_vec.extend(sample_vec.iter().cloned()); + } + (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => { + combined_vec.extend(sample_vec.iter().cloned()); + } + (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => { + combined_vec.extend(sample_vec.iter().cloned()); + } + _ => panic!("Type mismatch when combining draws for {}", key), + } + + combined_draws.insert(key, combined); + } + + Ok(HashMapResult { + stats: combined_stats, + draws: combined_draws, + }) + } + + /// Flush - no-op for HashMap storage since everything is in memory + fn flush(&self) -> Result<()> { + Ok(()) + } +} + +pub struct HashMapConfig {} + +impl Default for HashMapConfig { + fn default() -> Self { + Self::new() + } +} + +impl HashMapConfig { + pub fn new() -> Self { + Self {} + } +} + +impl StorageConfig for HashMapConfig { + type Storage = HashMapTraceStorage; + + fn new_trace( + self, + settings: &impl Settings, + math: &M, + ) -> Result { + Ok(HashMapTraceStorage { + param_types: settings.stat_types(math), + draw_types: settings.data_types(math), + }) + } +} + +impl TraceStorage for HashMapTraceStorage { + type ChainStorage = HashMapChainStorage; + + type Finalized = Vec; + + fn initialize_trace_for_chain(&self, _chain_id: u64) -> Result { + Ok(HashMapChainStorage::new( + &self.param_types, + 
&self.draw_types, + )) + } + + fn finalize( + self, + traces: Vec::Finalized>>, + ) -> Result<(Option, Self::Finalized)> { + let mut results = Vec::new(); + let mut first_error = None; + + for trace in traces { + match trace { + Ok(result) => results.push(result), + Err(err) => { + if first_error.is_none() { + first_error = Some(err); + } + } + } + } + + Ok((first_error, results)) + } +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs new file mode 100644 index 0000000..d8c04c3 --- /dev/null +++ b/src/storage/mod.rs @@ -0,0 +1,17 @@ +mod csv; +mod hashmap; +#[cfg(feature = "ndarray")] +mod ndarray; +mod storage; +#[cfg(feature = "zarr")] +mod zarr; + +#[cfg(feature = "zarr")] +pub use zarr::{ZarrAsyncConfig, ZarrAsyncTraceStorage, ZarrConfig, ZarrTraceStorage}; + +pub use csv::{CsvConfig, CsvTraceStorage}; +pub use hashmap::{HashMapConfig, HashMapValue}; +#[cfg(feature = "ndarray")] +pub use ndarray::{NdarrayConfig, NdarrayTrace, NdarrayValue}; + +pub use storage::{ChainStorage, StorageConfig, TraceStorage}; diff --git a/src/storage/ndarray.rs b/src/storage/ndarray.rs new file mode 100644 index 0000000..c02dfff --- /dev/null +++ b/src/storage/ndarray.rs @@ -0,0 +1,345 @@ +use anyhow::{Context, Result}; +use ndarray::{ArrayD, IxDyn}; +use nuts_storable::{ItemType, Value}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use crate::storage::{ChainStorage, StorageConfig, TraceStorage}; +use crate::{Math, Progress, Settings}; + +/// Container for different types of ndarray values +#[derive(Debug, Clone)] +pub enum NdarrayValue { + F64(ArrayD), + F32(ArrayD), + Bool(ArrayD), + I64(ArrayD), + U64(ArrayD), + String(ArrayD), +} + +impl NdarrayValue { + /// Create a new ndarray with the specified type and shape + fn new(item_type: ItemType, shape: &[usize]) -> Self { + match item_type { + ItemType::F64 => NdarrayValue::F64(ArrayD::zeros(IxDyn(shape))), + ItemType::F32 => NdarrayValue::F32(ArrayD::zeros(IxDyn(shape))), + ItemType::Bool => NdarrayValue::Bool(ArrayD::from_elem(IxDyn(shape), false)), + ItemType::I64 => NdarrayValue::I64(ArrayD::zeros(IxDyn(shape))), + ItemType::U64 => NdarrayValue::U64(ArrayD::zeros(IxDyn(shape))), + ItemType::String => { + NdarrayValue::String(ArrayD::from_elem(IxDyn(shape), String::new())) + } + } + } + + /// Set values at the specified indices + fn set_value(&mut self, indices: &[usize], value: Value) -> Result<()> { + match (self, value) { + (NdarrayValue::F64(arr), Value::ScalarF64(v)) => { + arr[IxDyn(indices)] = v; + } + (NdarrayValue::F32(arr), Value::ScalarF32(v)) => { + arr[IxDyn(indices)] = v; + } + (NdarrayValue::Bool(arr), Value::ScalarBool(v)) => { + arr[IxDyn(indices)] = v; + } + (NdarrayValue::I64(arr), Value::ScalarI64(v)) => { + arr[IxDyn(indices)] = v; + } + (NdarrayValue::U64(arr), Value::ScalarU64(v)) => { + arr[IxDyn(indices)] = v; + } + (NdarrayValue::F64(arr), Value::F64(v)) => { + // For vector values, we need to handle the extra dimensions + if indices.len() == 2 { + // Simple case: just set the slice + let mut view = arr.slice_mut(ndarray::s![indices[0], indices[1], ..]); + for (i, val) in v.iter().enumerate() { + view[i] = *val; + } + } else { + return Err(anyhow::anyhow!( + "Vector assignment with complex indices not implemented" + )); + } + } + (NdarrayValue::F32(arr), Value::F32(v)) => { + if indices.len() == 2 { + let mut view = arr.slice_mut(ndarray::s![indices[0], indices[1], ..]); + for (i, val) in v.iter().enumerate() { + view[i] = *val; + } + } else { + return Err(anyhow::anyhow!( + "Vector assignment with 
complex indices not implemented" + )); + } + } + (NdarrayValue::Bool(arr), Value::Bool(v)) => { + if indices.len() == 2 { + let mut view = arr.slice_mut(ndarray::s![indices[0], indices[1], ..]); + for (i, val) in v.iter().enumerate() { + view[i] = *val; + } + } else { + return Err(anyhow::anyhow!( + "Vector assignment with complex indices not implemented" + )); + } + } + (NdarrayValue::I64(arr), Value::I64(v)) => { + if indices.len() == 2 { + let mut view = arr.slice_mut(ndarray::s![indices[0], indices[1], ..]); + for (i, val) in v.iter().enumerate() { + view[i] = *val; + } + } else { + return Err(anyhow::anyhow!( + "Vector assignment with complex indices not implemented" + )); + } + } + (NdarrayValue::U64(arr), Value::U64(v)) => { + if indices.len() == 2 { + let mut view = arr.slice_mut(ndarray::s![indices[0], indices[1], ..]); + for (i, val) in v.iter().enumerate() { + view[i] = *val; + } + } else { + return Err(anyhow::anyhow!( + "Vector assignment with complex indices not implemented" + )); + } + } + _ => return Err(anyhow::anyhow!("Mismatched item type")), + } + Ok(()) + } +} + +/// Final result containing the collected samples as ndarrays +#[derive(Debug, Clone)] +pub struct NdarrayTrace { + /// HashMap containing sampler stats as ndarrays with shape (n_chains, n_draws, *extra_dims) + pub stats: HashMap, + /// HashMap containing draws as ndarrays with shape (n_chains, n_draws, *extra_dims) + pub draws: HashMap, +} + +/// Shared storage container with interior mutability +#[derive(Clone)] +struct SharedArrays { + stats_arrays: HashMap, + draws_arrays: HashMap, +} + +/// Main storage for ndarray MCMC traces +pub struct NdarrayTraceStorage { + shared_arrays: Arc>, +} + +/// Per-chain storage for ndarray MCMC traces +pub struct NdarrayChainStorage { + shared_arrays: Arc>, + chain: usize, + current_draw: usize, +} + +impl NdarrayChainStorage { + /// Create a new chain storage + fn new(trace_storage: &NdarrayTraceStorage, chain: usize) -> Self { + Self { + shared_arrays: trace_storage.shared_arrays.clone(), + chain, + current_draw: 0, + } + } + + /// Store a parameter value in the ndarray + fn push_param(&mut self, name: &str, value: Value) -> Result<()> { + if ["draw", "chain"].contains(&name) { + return Ok(()); + } + + let mut shared = self.shared_arrays.lock().unwrap(); + if let Some(array) = shared.stats_arrays.get_mut(name) { + let indices = vec![self.chain, self.current_draw]; + array.set_value(&indices, value)?; + } else { + return Err(anyhow::anyhow!("Unknown param name: {}", name)); + } + Ok(()) + } + + /// Store a draw value in the ndarray + fn push_draw(&mut self, name: &str, value: Value) -> Result<()> { + if ["draw", "chain"].contains(&name) { + return Ok(()); + } + + let mut shared = self.shared_arrays.lock().unwrap(); + if let Some(array) = shared.draws_arrays.get_mut(name) { + let indices = vec![self.chain, self.current_draw]; + array.set_value(&indices, value)?; + } else { + return Err(anyhow::anyhow!("Unknown posterior variable name: {}", name)); + } + Ok(()) + } +} + +pub struct NdarrayConfig {} + +impl NdarrayConfig { + pub fn new() -> Self { + Self {} + } +} + +impl StorageConfig for NdarrayConfig { + type Storage = NdarrayTraceStorage; + + fn new_trace(self, settings: &impl Settings, math: &M) -> Result { + let n_chains = settings.num_chains(); + let n_tune = settings.hint_num_tune(); + let n_draws = settings.hint_num_draws(); + let total_draws = n_tune + n_draws; + + let mut stats_arrays = HashMap::new(); + let mut draws_arrays = HashMap::new(); + + let dim_sizes = 
math.dim_sizes();
+
+        // Create arrays for stats
+        for ((name, extra_dims), (name2, item_type)) in settings
+            .stat_dims_all(math)
+            .into_iter()
+            .zip(settings.stat_types(math).into_iter())
+        {
+            assert!(name == name2);
+            if ["draw", "chain"].contains(&name.as_str()) {
+                continue;
+            }
+
+            // Build shape: [n_chains, total_draws, ...extra_dims]
+            let mut shape = vec![n_chains, total_draws];
+            for dim in extra_dims {
+                let dim_size = *dim_sizes
+                    .get(&dim.to_string())
+                    .context(format!("Unknown dimension: {}", dim))?
+                    as usize;
+                shape.push(dim_size);
+            }
+
+            let array = NdarrayValue::new(item_type, &shape);
+            stats_arrays.insert(name, array);
+        }
+
+        // Create arrays for draws (note: these use the data dims and types,
+        // not the stat dims and types used above)
+        for ((name, extra_dims), (name2, item_type)) in settings
+            .data_dims_all(math)
+            .into_iter()
+            .zip(settings.data_types(math).into_iter())
+        {
+            assert!(name == name2);
+            if ["draw", "chain"].contains(&name.as_str()) {
+                continue;
+            }
+            // Build shape: [n_chains, total_draws, ...extra_dims]
+            let mut shape = vec![n_chains, total_draws];
+            for dim in extra_dims {
+                let dim_size = *dim_sizes
+                    .get(&dim.to_string())
+                    .context(format!("Unknown dimension: {}", dim))?
+                    as usize;
+                shape.push(dim_size);
+            }
+
+            let array = NdarrayValue::new(item_type, &shape);
+            draws_arrays.insert(name, array);
+        }
+
+        let shared_arrays = Arc::new(Mutex::new(SharedArrays {
+            stats_arrays,
+            draws_arrays,
+        }));
+
+        Ok(NdarrayTraceStorage { shared_arrays })
+    }
+}
+
+impl ChainStorage for NdarrayChainStorage {
+    type Finalized = ();
+
+    fn record_sample(
+        &mut self,
+        _settings: &impl Settings,
+        stats: Vec<(&str, Option<Value>)>,
+        draws: Vec<(&str, Option<Value>)>,
+        _info: &Progress,
+    ) -> Result<()> {
+        for (name, value) in stats {
+            if let Some(value) = value {
+                self.push_param(name, value)?;
+            }
+        }
+        for (name, value) in draws {
+            if let Some(value) = value {
+                self.push_draw(name, value)?;
+            } else {
+                return Err(anyhow::anyhow!("Missing draw value for {}", name));
+            }
+        }
+        self.current_draw += 1;
+        Ok(())
+    }
+
+    /// Finalize storage - nothing to do for ndarray storage
+    fn finalize(self) -> Result<Self::Finalized> {
+        Ok(())
+    }
+
+    /// Flush - no-op for ndarray storage since everything is in shared arrays
+    fn flush(&self) -> Result<()> {
+        Ok(())
+    }
+}
+
+impl TraceStorage for NdarrayTraceStorage {
+    type ChainStorage = NdarrayChainStorage;
+
+    type Finalized = NdarrayTrace;
+
+    fn initialize_trace_for_chain(&self, chain_id: u64) -> Result<Self::ChainStorage> {
+        Ok(NdarrayChainStorage::new(self, chain_id as usize))
+    }
+
+    fn finalize(
+        self,
+        traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
+    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
+        let mut first_error = None;

+        for trace in traces {
+            if let Err(err) = trace {
+                if first_error.is_none() {
+                    first_error = Some(err);
+                }
+            }
+        }
+
+        // Clone the arrays out of the shared container (the chain storages hold
+        // clones of the Arc, so we cannot take ownership of its contents)
+        let shared_arrays = self.shared_arrays.lock().unwrap();
+        let stats_arrays = shared_arrays.stats_arrays.clone();
+        let draws_arrays = shared_arrays.draws_arrays.clone();
+        drop(shared_arrays);
+
+        let result = NdarrayTrace {
+            stats: stats_arrays,
+            draws: draws_arrays,
+        };
+
+        Ok((first_error, result))
+    }
+}
diff --git a/src/storage/storage.rs b/src/storage/storage.rs
new file mode 100644
index 0000000..5032c1e
--- /dev/null
+++ b/src/storage/storage.rs
@@ -0,0 +1,66 @@
+use anyhow::Result;
+use nuts_storable::Value;
+
+use crate::{Math, Progress, Settings};
+
+/// Trait for storing MCMC sampling results from a single chain.
diff --git a/src/storage/storage.rs b/src/storage/storage.rs
new file mode 100644
index 0000000..5032c1e
--- /dev/null
+++ b/src/storage/storage.rs
@@ -0,0 +1,66 @@
use anyhow::Result;
use nuts_storable::Value;

use crate::{Math, Progress, Settings};

/// Trait for storing MCMC sampling results from a single chain.
///
/// Handles progressive accumulation of statistics and draws during sampling,
/// with methods to record samples and finalize results.
pub trait ChainStorage: Send {
    /// The type returned when the chain storage is finalized.
    type Finalized: Send + Sync + 'static;

    /// Appends a new sample to the storage.
    fn record_sample(
        &mut self,
        settings: &impl Settings,
        stats: Vec<(&str, Option<Value>)>,
        draws: Vec<(&str, Option<Value>)>,
        info: &Progress,
    ) -> Result<()>;

    /// Finalizes the storage and returns processed results.
    fn finalize(self) -> Result<Self::Finalized>;

    /// Flush any buffered data to ensure all samples are stored.
    fn flush(&self) -> Result<()>;
}

/// Configuration trait for creating MCMC storage backends.
///
/// This is the main user-facing trait for configuring storage. Users choose
/// a storage backend by providing an implementation of this trait to the
/// sampling functions.
pub trait StorageConfig: Send + 'static {
    /// The storage backend type this config creates.
    type Storage: TraceStorage;

    /// Creates a new storage backend instance.
    fn new_trace<M: Math>(self, settings: &impl Settings, math: &M) -> Result<Self::Storage>;
}

/// Trait for managing storage across multiple MCMC chains.
///
/// Defines the interface for initializing chain storage and combining results
/// from multiple chains into a final result.
pub trait TraceStorage: Send + Sync + Sized + 'static {
    /// The storage type for individual chains.
    type ChainStorage: ChainStorage;

    /// The final result type combining all chains.
    type Finalized: Send + Sync + 'static;

    /// Create storage for a single chain.
    fn initialize_trace_for_chain(&self, chain_id: u64) -> Result<Self::ChainStorage>;

    /// Combine results from all chains into final output.
    ///
    /// # Arguments
    ///
    /// * `traces` - Finalized results from all chains
    fn finalize(
        self,
        traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
    ) -> Result<(Option<anyhow::Error>, Self::Finalized)>;
}
diff --git a/src/storage/zarr/async_impl.rs b/src/storage/zarr/async_impl.rs
new file mode 100644
index 0000000..acde6de
--- /dev/null
+++ b/src/storage/zarr/async_impl.rs
@@ -0,0 +1,657 @@
use std::collections::HashMap;
use std::iter::once;
use std::sync::Arc;
use tokio::task::JoinHandle;

use anyhow::{Context, Result};
use nuts_storable::{ItemType, Value};
use zarrs::array::{ArrayBuilder, DataType, FillValue};
use zarrs::array_subset::ArraySubset;
use zarrs::group::GroupBuilder;
use zarrs::storage::{
    AsyncReadableWritableListableStorage, AsyncReadableWritableListableStorageTraits,
};

use super::common::{Chunk, SampleBuffer, SampleBufferValue, create_arrays};
use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
use crate::{Math, Progress, Settings};

pub type Array = Arc<zarrs::array::Array<dyn AsyncReadableWritableListableStorageTraits>>;

struct ArrayCollection {
    pub warmup_param_arrays: HashMap<String, Array>,
    pub sample_param_arrays: HashMap<String, Array>,
    pub warmup_draw_arrays: HashMap<String, Array>,
    pub sample_draw_arrays: HashMap<String, Array>,
}

/// Main storage for async Zarr MCMC traces
pub struct ZarrAsyncTraceStorage {
    arrays: Arc<ArrayCollection>,
    draw_chunk_size: u64,
    param_types: Vec<(String, ItemType)>,
    draw_types: Vec<(String, ItemType)>,
    rt_handle: tokio::runtime::Handle,
}

/// Per-chain storage for async Zarr MCMC traces
pub struct ZarrAsyncChainStorage {
    draw_buffers: HashMap<String, SampleBuffer>,
    stats_buffers: HashMap<String, SampleBuffer>,
    arrays: Arc<ArrayCollection>,
    chain: u64,
    last_sample_was_warmup: bool,
    pending_writes: Vec<JoinHandle<Result<()>>>,
    rt_handle: tokio::runtime::Handle,
}
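// Buffered-write pipeline: each chain fills its `SampleBuffer`s locally; once
// a buffer reaches `draw_chunk_size` draws, the resulting chunk is handed to
// the tokio runtime as a spawned task, and the `JoinHandle`s collected in
// `pending_writes` are awaited in `finalize`.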
/// Write a chunk of data to a Zarr array asynchronously
async fn store_zarr_chunk_async(array: Array, data: Chunk, chain_chunk_index: u64) -> Result<()> {
    let rank = array.chunk_grid().dimensionality();
    assert!(rank >= 2);
    // Chunk indices are [chain, draw-chunk, 0, ...], one entry per array dimension.
    let chunk_vec: Vec<_> = once(chain_chunk_index)
        .chain(once(data.chunk_idx as u64))
        .chain(once(0).cycle().take(rank - 2))
        .collect();
    let chunk = &chunk_vec[..];

    let result = if data.is_full() {
        match data.values {
            SampleBufferValue::F64(v) => array.async_store_chunk_elements::<f64>(chunk, &v).await,
            SampleBufferValue::F32(v) => array.async_store_chunk_elements::<f32>(chunk, &v).await,
            SampleBufferValue::U64(v) => array.async_store_chunk_elements::<u64>(chunk, &v).await,
            SampleBufferValue::I64(v) => array.async_store_chunk_elements::<i64>(chunk, &v).await,
            SampleBufferValue::Bool(v) => {
                array.async_store_chunk_elements::<bool>(chunk, &v).await
            }
        }
    } else {
        let mut shape: Vec<_> = array.shape().iter().cloned().collect();
        assert!(shape.len() >= 2);
        shape[0] = 1;
        shape[1] = data.len as u64;
        let chunk_subset = ArraySubset::new_with_shape(shape);
        match data.values {
            SampleBufferValue::F64(v) => {
                assert!(v.len() == chunk_subset.num_elements_usize());
                array
                    .async_store_chunk_subset_elements(chunk, &chunk_subset, &v)
                    .await
            }
            SampleBufferValue::F32(v) => {
                assert!(v.len() == chunk_subset.num_elements_usize());
                array
                    .async_store_chunk_subset_elements(chunk, &chunk_subset, &v)
                    .await
            }
            SampleBufferValue::U64(v) => {
                assert!(v.len() == chunk_subset.num_elements_usize());
                array
                    .async_store_chunk_subset_elements(chunk, &chunk_subset, &v)
                    .await
            }
            SampleBufferValue::I64(v) => {
                assert!(v.len() == chunk_subset.num_elements_usize());
                array
                    .async_store_chunk_subset_elements(chunk, &chunk_subset, &v)
                    .await
            }
            SampleBufferValue::Bool(v) => {
                assert!(v.len() == chunk_subset.num_elements_usize());
                array
                    .async_store_chunk_subset_elements(chunk, &chunk_subset, &v)
                    .await
            }
        }
    };

    result.context(format!(
        "Failed to store chunk for variable {} at chunk {} with length {}",
        array.path(),
        data.chunk_idx,
        data.len
    ))?;
    Ok(())
}

/// Store a chunk synchronously by blocking on the async operation
fn store_zarr_chunk_sync(
    handle: &tokio::runtime::Handle,
    array: &Array,
    data: Chunk,
    chain_chunk_index: u64,
) -> Result<()> {
    let array = array.clone();
    // Block this (non-runtime) worker thread until the write completes.
    handle.block_on(store_zarr_chunk_async(array, data, chain_chunk_index))
}

/// Store coordinates in zarr arrays
async fn store_coords(
    store: AsyncReadableWritableListableStorage,
    group: String,
    coords: &HashMap<String, Value>,
) -> Result<()> {
    for (name, coord) in coords {
        let (data_type, len, fill_value) = match coord {
            &Value::F64(ref v) => (DataType::Float64, v.len(), FillValue::from(f64::NAN)),
            &Value::F32(ref v) => (DataType::Float32, v.len(), FillValue::from(f32::NAN)),
            &Value::U64(ref v) => (DataType::UInt64, v.len(), FillValue::from(0u64)),
            &Value::I64(ref v) => (DataType::Int64, v.len(), FillValue::from(0i64)),
            &Value::Bool(ref v) => (DataType::Bool, v.len(), FillValue::from(false)),
            &Value::Strings(ref v) => (DataType::String, v.len(), FillValue::from("")),
            _ => panic!("Unsupported coordinate type for {}", name),
        };
        let name: &String = name;
        let coord_array = ArrayBuilder::new(
            vec![len as u64],
            data_type,
            vec![len as u64].try_into().expect("Invalid chunk size"),
            fill_value,
        )
        .dimension_names(Some(vec![name.to_string()]))
        .build(store.clone(), &format!("{}/{}", group, name))?;
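        // Each coordinate array consists of a single chunk, so the data is
        // written as chunk 0 of the array that was just created.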
let subset = vec![0]; + match coord { + &Value::F64(ref v) => { + coord_array + .async_store_chunk_elements::(&subset, v) + .await? + } + &Value::F32(ref v) => { + coord_array + .async_store_chunk_elements::(&subset, v) + .await? + } + &Value::U64(ref v) => { + coord_array + .async_store_chunk_elements::(&subset, v) + .await? + } + &Value::I64(ref v) => { + coord_array + .async_store_chunk_elements::(&subset, v) + .await? + } + &Value::Bool(ref v) => { + coord_array + .async_store_chunk_elements::(&subset, v) + .await? + } + &Value::Strings(ref v) => { + coord_array + .async_store_chunk_elements::(&subset, v) + .await? + } + _ => unreachable!(), + } + coord_array.async_store_metadata().await?; + } + Ok(()) +} + +impl ZarrAsyncChainStorage { + /// Create a new chain storage with buffers for parameters and samples + fn new( + arrays: Arc, + param_types: &Vec<(String, ItemType)>, + draw_types: &Vec<(String, ItemType)>, + buffer_size: u64, + chain: u64, + rt_handle: tokio::runtime::Handle, + ) -> Self { + let draw_buffers = draw_types + .iter() + .map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size))) + .collect(); + + let stats_buffers = param_types + .iter() + .map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size))) + .collect(); + Self { + draw_buffers, + stats_buffers, + arrays, + chain, + last_sample_was_warmup: true, + pending_writes: Vec::new(), + rt_handle, + } + } + + /// Store a parameter value, spawning async write when buffer is full + fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> { + if ["draw", "chain"].contains(&name) { + return Ok(()); + } + let Some(buffer) = self.stats_buffers.get_mut(name) else { + panic!("Unknown param name: {}", name); + }; + if let Some(chunk) = buffer.push(value) { + let array = if is_warmup { + self.arrays.warmup_param_arrays[name].clone() + } else { + self.arrays.sample_param_arrays[name].clone() + }; + let chain = self.chain; + let handle = self + .rt_handle + .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await }); + self.pending_writes.push(handle); + } + Ok(()) + } + + /// Store a draw value, spawning async write when buffer is full + fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> { + if ["draw", "chain"].contains(&name) { + return Ok(()); + } + let Some(buffer) = self.draw_buffers.get_mut(name) else { + panic!("Unknown posterior variable name: {}", name); + }; + if let Some(chunk) = buffer.push(value) { + let array = if is_warmup { + self.arrays.warmup_draw_arrays[name].clone() + } else { + self.arrays.sample_draw_arrays[name].clone() + }; + let chain = self.chain; + let handle = self + .rt_handle + .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await }); + self.pending_writes.push(handle); + } + Ok(()) + } +} + +impl ChainStorage for ZarrAsyncChainStorage { + type Finalized = (); + + fn record_sample( + &mut self, + _settings: &impl Settings, + stats: Vec<(&str, Option)>, + draws: Vec<(&str, Option)>, + info: &Progress, + ) -> Result<()> { + let is_first_draw = self.last_sample_was_warmup && !info.tuning; + if is_first_draw { + for (key, buffer) in self.draw_buffers.iter_mut() { + if let Some(chunk) = buffer.reset() { + let array = self.arrays.warmup_draw_arrays[key].clone(); + let chain = self.chain; + let handle = self + .rt_handle + .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await }); + self.pending_writes.push(handle); + } + } + for (key, buffer) in 
self.stats_buffers.iter_mut() {
                if let Some(chunk) = buffer.reset() {
                    let array = self.arrays.warmup_param_arrays[key].clone();
                    let chain = self.chain;
                    let handle = self
                        .rt_handle
                        .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
                    self.pending_writes.push(handle);
                }
            }
            self.last_sample_was_warmup = false;
        }

        for (name, value) in stats {
            if let Some(value) = value {
                self.push_param(name, value, info.tuning)?;
            }
        }
        for (name, value) in draws {
            if let Some(value) = value {
                self.push_draw(name, value, info.tuning)?;
            } else {
                panic!("Missing draw value for {}", name);
            }
        }
        Ok(())
    }

    /// Flush remaining samples and finalize storage, joining all pending writes
    fn finalize(self) -> Result<Self::Finalized> {
        // Handle remaining buffers synchronously
        for (key, mut buffer) in self.draw_buffers.into_iter() {
            if let Some(chunk) = buffer.reset() {
                let array = if self.last_sample_was_warmup {
                    &self.arrays.warmup_draw_arrays[&key]
                } else {
                    &self.arrays.sample_draw_arrays[&key]
                };
                store_zarr_chunk_sync(&self.rt_handle, array, chunk, self.chain)?;
            }
        }
        for (key, mut buffer) in self.stats_buffers.into_iter() {
            if let Some(chunk) = buffer.reset() {
                let array = if self.last_sample_was_warmup {
                    &self.arrays.warmup_param_arrays[&key]
                } else {
                    &self.arrays.sample_param_arrays[&key]
                };
                store_zarr_chunk_sync(&self.rt_handle, array, chunk, self.chain)?;
            }
        }

        // Join all pending writes, propagating both join errors and errors
        // from the write operations themselves.
        self.rt_handle.block_on(async move {
            for join_handle in self.pending_writes {
                join_handle
                    .await
                    .context("Failed to await async chunk write operation")??;
            }
            Ok::<(), anyhow::Error>(())
        })?;

        Ok(())
    }

    /// Write current buffer contents to storage without modifying the buffers
    fn flush(&self) -> Result<()> {
        // Flush all draw buffers that have data (synchronously)
        for (key, buffer) in &self.draw_buffers {
            if let Some(temp_chunk) = buffer.copy_as_chunk() {
                let array = if self.last_sample_was_warmup {
                    &self.arrays.warmup_draw_arrays[key]
                } else {
                    &self.arrays.sample_draw_arrays[key]
                };
                store_zarr_chunk_sync(&self.rt_handle, array, temp_chunk, self.chain)?;
            }
        }

        // Flush all stats buffers that have data (synchronously)
        for (key, buffer) in &self.stats_buffers {
            if let Some(temp_chunk) = buffer.copy_as_chunk() {
                let array = if self.last_sample_was_warmup {
                    &self.arrays.warmup_param_arrays[key]
                } else {
                    &self.arrays.sample_param_arrays[key]
                };
                store_zarr_chunk_sync(&self.rt_handle, array, temp_chunk, self.chain)?;
            }
        }

        Ok(())
    }
}

/// Configuration for async Zarr-based MCMC storage.
///
/// This is the async version of ZarrConfig that uses tokio for async I/O operations.
/// It provides the same interface but spawns tasks for write operations to avoid
/// blocking the sampling process.
///
/// The storage organizes data into groups:
/// - `posterior/` - posterior samples
/// - `sample_stats/` - sampling statistics
/// - `warmup_posterior/` - warmup samples (optional)
/// - `warmup_sample_stats/` - warmup statistics (optional)
pub struct ZarrAsyncConfig {
    store: AsyncReadableWritableListableStorage,
    group_path: Option<String>,
    draw_chunk_size: u64,
    store_warmup: bool,
    rt_handle: tokio::runtime::Handle,
}

impl ZarrAsyncConfig {
    /// Create a new async Zarr configuration with default settings.
+ /// + /// Default settings: + /// - `draw_chunk_size`: 100 samples per chunk + /// - `store_warmup`: true (warmup samples are stored) + /// - `group_path`: root of the store + pub fn new( + rt_handle: tokio::runtime::Handle, + store: AsyncReadableWritableListableStorage, + ) -> Self { + Self { + store, + group_path: None, + draw_chunk_size: 100, + store_warmup: true, + rt_handle, + } + } + + /// Set the number of samples per chunk. + /// + /// Larger chunks use more memory but may provide better I/O performance. + /// Smaller chunks provide more frequent flushing and lower memory usage. + pub fn with_chunk_size(mut self, chunk_size: u64) -> Self { + self.draw_chunk_size = chunk_size; + self + } + + /// Set the group path within the Zarr store. + /// + /// If not set, data is stored at the root of the store. + pub fn with_group_path>(mut self, path: S) -> Self { + self.group_path = Some(path.into()); + self + } + + /// Configure whether to store warmup samples. + /// + /// When true, warmup samples are stored in separate groups. + /// When false, only post-warmup samples are stored. + pub fn store_warmup(mut self, store: bool) -> Self { + self.store_warmup = store; + self + } +} + +impl StorageConfig for ZarrAsyncConfig { + type Storage = ZarrAsyncTraceStorage; + + fn new_trace(self, settings: &impl Settings, math: &M) -> Result { + let handle = self.rt_handle.clone(); + let rt_handle = handle.clone(); + handle.block_on(async move { + let n_chains = settings.num_chains() as u64; + let n_tune = settings.hint_num_tune() as u64; + let n_draws = settings.hint_num_draws() as u64; + + let param_types = settings.stat_types(math); + let draw_types = settings.data_types(math); + + let param_dims = settings.stat_dims_all(math); + let draw_dims = settings.data_dims_all(math); + + let draw_dim_sizes = math.dim_sizes(); + let stat_dim_sizes = settings.stat_dim_sizes(math); + + let mut group_path = self.group_path.unwrap_or_else(|| "".to_string()); + if !group_path.ends_with('/') { + group_path.push('/'); + } + let store = self.store; + let draw_chunk_size = self.draw_chunk_size; + + let mut root = GroupBuilder::new().build(store.clone(), &group_path)?; + + let attrs = root.attributes_mut(); + attrs.insert( + "sampler".to_string(), + serde_json::Value::String(env!("CARGO_PKG_NAME").to_string()), + ); + attrs.insert( + "sampler_version".to_string(), + serde_json::Value::String(env!("CARGO_PKG_VERSION").to_string()), + ); + attrs.insert( + "sampler_settings".to_string(), + serde_json::to_value(settings).context("Could not serialize sampler settings")?, + ); + root.async_store_metadata().await?; + + GroupBuilder::new() + .build(store.clone(), &format!("{}warmup_posterior", group_path))? + .async_store_metadata() + .await?; + GroupBuilder::new() + .build(store.clone(), &format!("{}warmup_sample_stats", group_path))? + .async_store_metadata() + .await?; + GroupBuilder::new() + .build(store.clone(), &format!("{}posterior", group_path))? + .async_store_metadata() + .await?; + GroupBuilder::new() + .build(store.clone(), &format!("{}sample_stats", group_path))? 
+ .async_store_metadata() + .await?; + + let warmup_param_arrays = create_arrays( + store.clone(), + &format!("{}warmup_sample_stats", group_path), + ¶m_types, + ¶m_dims, + n_chains, + n_tune, + &stat_dim_sizes, + self.draw_chunk_size, + )?; + let sample_param_arrays = create_arrays( + store.clone(), + &format!("{}sample_stats", group_path), + ¶m_types, + ¶m_dims, + n_chains, + n_draws, + &stat_dim_sizes, + self.draw_chunk_size, + )?; + let warmup_draw_arrays = create_arrays( + store.clone(), + &format!("{}warmup_posterior", group_path), + &draw_types, + &draw_dims, + n_chains, + n_tune, + &draw_dim_sizes, + self.draw_chunk_size, + )?; + let sample_draw_arrays = create_arrays( + store.clone(), + &format!("{}posterior", group_path), + &draw_types, + &draw_dims, + n_chains, + n_draws, + &draw_dim_sizes, + self.draw_chunk_size, + )?; + // add arc around each value + let warmup_param_arrays: HashMap<_, _> = warmup_param_arrays + .into_iter() + .map(|(k, v)| (k, Arc::new(v) as Array)) + .collect(); + let sample_param_arrays: HashMap<_, _> = sample_param_arrays + .into_iter() + .map(|(k, v)| (k, Arc::new(v) as Array)) + .collect(); + let warmup_draw_arrays: HashMap<_, _> = warmup_draw_arrays + .into_iter() + .map(|(k, v)| (k, Arc::new(v) as Array)) + .collect(); + let sample_draw_arrays: HashMap<_, _> = sample_draw_arrays + .into_iter() + .map(|(k, v)| (k, Arc::new(v) as Array)) + .collect(); + for array in warmup_param_arrays + .values() + .chain(sample_param_arrays.values()) + .chain(warmup_draw_arrays.values()) + .chain(sample_draw_arrays.values()) + { + array.async_store_metadata().await?; + } + let trace_storage = ArrayCollection { + warmup_param_arrays, + sample_param_arrays, + warmup_draw_arrays, + sample_draw_arrays, + }; + + let draw_coords = math.coords(); + let stat_coords = settings.stat_coords(math); + + store_coords( + store.clone(), + format!("{}posterior", &group_path), + &draw_coords, + ) + .await?; + store_coords( + store.clone(), + format!("{}warmup_posterior", &group_path), + &draw_coords, + ) + .await?; + store_coords( + store.clone(), + format!("{}sample_stats", &group_path), + &stat_coords, + ) + .await?; + store_coords( + store.clone(), + format!("{}warmup_sample_stats", &group_path), + &stat_coords, + ) + .await?; + Ok(ZarrAsyncTraceStorage { + arrays: Arc::new(trace_storage), + param_types, + draw_types, + draw_chunk_size, + rt_handle, + }) + }) + } +} + +impl TraceStorage for ZarrAsyncTraceStorage { + type ChainStorage = ZarrAsyncChainStorage; + + type Finalized = (); + + fn initialize_trace_for_chain(&self, chain_id: u64) -> Result { + Ok(ZarrAsyncChainStorage::new( + self.arrays.clone(), + &self.param_types, + &self.draw_types, + self.draw_chunk_size, + chain_id as _, + self.rt_handle.clone(), + )) + } + + fn finalize( + self, + traces: Vec::Finalized>>, + ) -> Result<(Option, Self::Finalized)> { + for trace in traces { + if let Err(err) = trace { + return Ok((Some(err), ())); + } + } + Ok((None, ())) + } +} diff --git a/src/storage/zarr/common.rs b/src/storage/zarr/common.rs new file mode 100644 index 0000000..734e5b0 --- /dev/null +++ b/src/storage/zarr/common.rs @@ -0,0 +1,224 @@ +use std::collections::HashMap; +use std::mem::replace; +use std::sync::Arc; + +use anyhow::Result; +use nuts_storable::{ItemType, Value}; +use zarrs::array::{Array, ArrayBuilder, DataType, FillValue}; + +/// Container for different types of sample values +#[derive(Clone, Debug)] +pub enum SampleBufferValue { + F64(Vec), + F32(Vec), + Bool(Vec), + I64(Vec), + U64(Vec), +} + +/// Buffer 
for collecting samples before writing to storage +pub struct SampleBuffer { + pub items: SampleBufferValue, + pub len: usize, + pub full_at: usize, + pub current_chunk: usize, +} + +/// A chunk of samples ready to be written to storage +#[derive(Debug)] +pub struct Chunk { + pub chunk_idx: usize, + pub len: usize, + pub full_at: usize, + pub values: SampleBufferValue, +} + +impl Chunk { + /// Check if the chunk has reached its capacity + pub fn is_full(&self) -> bool { + self.full_at == self.len + } +} + +impl SampleBuffer { + /// Create a new sample buffer with specified type and chunk size + pub fn new(item_type: ItemType, chunk_size: u64) -> Self { + let chunk_size = chunk_size.try_into().expect("Chunk size too large"); + let inner = match item_type { + ItemType::F64 => SampleBufferValue::F64(Vec::with_capacity(chunk_size)), + ItemType::F32 => SampleBufferValue::F32(Vec::with_capacity(chunk_size)), + ItemType::U64 => SampleBufferValue::U64(Vec::with_capacity(chunk_size)), + ItemType::Bool => SampleBufferValue::Bool(Vec::with_capacity(chunk_size)), + ItemType::I64 => SampleBufferValue::I64(Vec::with_capacity(chunk_size)), + ItemType::String => panic!("String type not supported in SampleBuffer"), + }; + Self { + items: inner, + len: 0, + full_at: chunk_size, + current_chunk: 0, + } + } + + /// Reset the buffer and return any accumulated data as a chunk + pub fn reset(&mut self) -> Option { + if self.len == 0 { + self.current_chunk = 0; + return None; + } + let out = self.finish_chunk(); + self.current_chunk = 0; + Some(out) + } + + /// Finalize the current chunk and prepare for a new one + pub fn finish_chunk(&mut self) -> Chunk { + let values = match &mut self.items { + SampleBufferValue::F64(vec) => { + SampleBufferValue::F64(replace(vec, Vec::with_capacity(vec.len()))) + } + SampleBufferValue::F32(vec) => { + SampleBufferValue::F32(replace(vec, Vec::with_capacity(vec.len()))) + } + SampleBufferValue::U64(vec) => { + SampleBufferValue::U64(replace(vec, Vec::with_capacity(vec.len()))) + } + SampleBufferValue::Bool(vec) => { + SampleBufferValue::Bool(replace(vec, Vec::with_capacity(vec.len()))) + } + SampleBufferValue::I64(vec) => { + SampleBufferValue::I64(replace(vec, Vec::with_capacity(vec.len()))) + } + }; + + let output = Chunk { + chunk_idx: self.current_chunk, + len: self.len, + values, + full_at: self.full_at, + }; + + self.current_chunk += 1; + self.len = 0; + output + } + + /// Creates a temporary chunk containing a copy of the current buffer's data + pub fn copy_as_chunk(&self) -> Option { + if self.len == 0 { + return None; + } + + let values = match &self.items { + SampleBufferValue::F64(vec) => SampleBufferValue::F64(vec.clone()), + SampleBufferValue::F32(vec) => SampleBufferValue::F32(vec.clone()), + SampleBufferValue::U64(vec) => SampleBufferValue::U64(vec.clone()), + SampleBufferValue::Bool(vec) => SampleBufferValue::Bool(vec.clone()), + SampleBufferValue::I64(vec) => SampleBufferValue::I64(vec.clone()), + }; + + Some(Chunk { + chunk_idx: self.current_chunk, + len: self.len, + values, + full_at: self.full_at, + }) + } + + /// Add an item to the buffer, returning a chunk if buffer becomes full + pub fn push(&mut self, item: Value) -> Option { + assert!(self.len < self.full_at); + match (&mut self.items, item) { + (SampleBufferValue::F64(vec), Value::ScalarF64(v)) => vec.push(v), + (SampleBufferValue::F32(vec), Value::ScalarF32(v)) => vec.push(v), + (SampleBufferValue::U64(vec), Value::ScalarU64(v)) => vec.push(v), + (SampleBufferValue::Bool(vec), Value::ScalarBool(v)) => 
vec.push(v), + (SampleBufferValue::I64(vec), Value::ScalarI64(v)) => vec.push(v), + (SampleBufferValue::F64(vec), Value::F64(v)) => vec.extend(v.into_iter()), + (SampleBufferValue::F32(vec), Value::F32(v)) => vec.extend(v.into_iter()), + (SampleBufferValue::U64(vec), Value::U64(v)) => vec.extend(v.into_iter()), + (SampleBufferValue::Bool(vec), Value::Bool(v)) => vec.extend(v.into_iter()), + (SampleBufferValue::I64(vec), Value::I64(v)) => vec.extend(v.into_iter()), + _ => panic!("Mismatched item type"), + } + self.len += 1; + + if self.len == self.full_at { + Some(self.finish_chunk()) + } else { + None + } + } +} + +/// Create Zarr arrays for storing MCMC trace data +pub fn create_arrays( + store: Arc, + group_path: &str, + item_types: &Vec<(String, ItemType)>, + item_dims: &Vec<(String, Vec)>, + n_chains: u64, + n_draws: u64, + dim_sizes: &HashMap, + draw_chunk_size: u64, +) -> Result>> { + let mut arrays = HashMap::new(); + for ((name1, item_type), (name2, extra_dims)) in item_types.iter().zip(item_dims.iter()) { + assert!(name1 == name2); + let name = name1; + if ["draw", "chain"].contains(&name.as_str()) { + continue; + } + let dims = std::iter::once("chain".to_string()) + .chain(std::iter::once("draw".to_string())) + .chain(extra_dims.iter().cloned()); + let extra_shape: Result> = extra_dims + .iter() + .map(|dim| { + dim_sizes + .get(dim) + .ok_or_else(|| { + anyhow::anyhow!("Unknown dimension size for dimension {}", dim) + .context(format!("Could not write {}/{}", group_path, name)) + }) + .map(|size| *size) + }) + .collect(); + let extra_shape = extra_shape?; + let shape: Vec = std::iter::once(n_chains as u64) + .chain(std::iter::once(n_draws as u64)) + .chain(extra_shape.clone()) + .collect(); + let zarr_type = match item_type { + ItemType::F64 => DataType::Float64, + ItemType::F32 => DataType::Float32, + ItemType::U64 => DataType::UInt64, + ItemType::I64 => DataType::Int64, + ItemType::Bool => DataType::Bool, + ItemType::String => DataType::String, + }; + let fill_value = match item_type { + ItemType::F64 => FillValue::from(f64::NAN), + ItemType::F32 => FillValue::from(f32::NAN), + ItemType::U64 => FillValue::from(0u64), + ItemType::I64 => FillValue::from(0i64), + ItemType::Bool => FillValue::from(false), + ItemType::String => FillValue::from(""), + }; + let grid: Vec = std::iter::once(1) + .chain(std::iter::once(draw_chunk_size)) + .chain(extra_shape) + .collect(); + let array = ArrayBuilder::new( + shape, + zarr_type, + grid.try_into().expect("Invalid chunk sizes"), + fill_value, + ) + .dimension_names(Some(dims)) + .build(store.clone(), &format!("{}/{}", group_path, name))?; + //array.store_metadata()?; + arrays.insert(name.to_string(), array); + } + Ok(arrays) +} diff --git a/src/storage/zarr/mod.rs b/src/storage/zarr/mod.rs new file mode 100644 index 0000000..be1df57 --- /dev/null +++ b/src/storage/zarr/mod.rs @@ -0,0 +1,9 @@ +pub mod common; +pub mod sync_impl; + +pub mod async_impl; + +pub use common::*; +pub use sync_impl::*; + +pub use async_impl::*; diff --git a/src/storage/zarr/sync_impl.rs b/src/storage/zarr/sync_impl.rs new file mode 100644 index 0000000..fb10d5c --- /dev/null +++ b/src/storage/zarr/sync_impl.rs @@ -0,0 +1,537 @@ +use std::collections::HashMap; +use std::iter::once; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use nuts_storable::{ItemType, Value}; +use zarrs::array::{ArrayBuilder, DataType, FillValue}; +use zarrs::array_subset::ArraySubset; +use zarrs::group::GroupBuilder; +use zarrs::storage::{ReadableWritableListableStorage, 
ReadableWritableListableStorageTraits}; + +use super::common::{Chunk, SampleBuffer, SampleBufferValue}; +use super::create_arrays; +use crate::storage::{ChainStorage, StorageConfig, TraceStorage}; +use crate::{Math, Progress, Settings}; + +pub type Array = zarrs::array::Array; + +struct ArrayCollection { + pub warmup_param_arrays: HashMap, + pub sample_param_arrays: HashMap, + pub warmup_draw_arrays: HashMap, + pub sample_draw_arrays: HashMap, +} + +/// Store coordinates in zarr arrays +pub fn store_coords( + store: ReadableWritableListableStorage, + group: String, + coords: &HashMap, +) -> Result<()> { + for (name, coord) in coords { + let (data_type, len, fill_value) = match coord { + &Value::F64(ref v) => (DataType::Float64, v.len(), FillValue::from(f64::NAN)), + &Value::F32(ref v) => (DataType::Float32, v.len(), FillValue::from(f32::NAN)), + &Value::U64(ref v) => (DataType::UInt64, v.len(), FillValue::from(0u64)), + &Value::I64(ref v) => (DataType::Int64, v.len(), FillValue::from(0i64)), + &Value::Bool(ref v) => (DataType::Bool, v.len(), FillValue::from(false)), + &Value::Strings(ref v) => (DataType::String, v.len(), FillValue::from("")), + _ => panic!("Unsupported coordinate type for {}", name), + }; + let name: &String = name; + let coord_array = ArrayBuilder::new( + vec![len as u64], + data_type, + vec![len as u64].try_into().expect("Invalid chunk size"), + fill_value, + ) + .dimension_names(Some(vec![name.to_string()])) + .build(store.clone(), &format!("{}/{}", group, name))?; + let subset = vec![0]; + match coord { + &Value::F64(ref v) => coord_array.store_chunk_elements::(&subset, v)?, + &Value::F32(ref v) => coord_array.store_chunk_elements::(&subset, v)?, + &Value::U64(ref v) => coord_array.store_chunk_elements::(&subset, v)?, + &Value::I64(ref v) => coord_array.store_chunk_elements::(&subset, v)?, + &Value::Bool(ref v) => coord_array.store_chunk_elements::(&subset, v)?, + &Value::Strings(ref v) => coord_array.store_chunk_elements::(&subset, v)?, + _ => unreachable!(), + } + coord_array.store_metadata()?; + } + Ok(()) +} + +/// Main storage for Zarr MCMC traces +pub struct ZarrTraceStorage { + arrays: Arc, + draw_chunk_size: u64, + param_types: Vec<(String, ItemType)>, + draw_types: Vec<(String, ItemType)>, +} + +/// Per-chain storage for Zarr MCMC traces +pub struct ZarrChainStorage { + draw_buffers: HashMap, + stats_buffers: HashMap, + arrays: Arc, + chain: u64, + last_sample_was_warmup: bool, +} + +/// Write a chunk of data to a Zarr array +fn store_zarr_chunk(array: &Array, data: Chunk, chain_chunk_index: u64) -> Result<()> { + let rank = array.chunk_grid().dimensionality(); + assert!(rank >= 2); + // append one value per rank + let chunk_vec: Vec<_> = once(chain_chunk_index as u64) + .chain(once(data.chunk_idx as u64)) + .chain(once(0).cycle().take(rank - 2)) + .collect(); + let chunk = &chunk_vec[..]; + + let result = if data.is_full() { + match data.values { + SampleBufferValue::F64(v) => array.store_chunk_elements::(&chunk, &v), + SampleBufferValue::F32(v) => array.store_chunk_elements::(&chunk, &v), + SampleBufferValue::U64(v) => array.store_chunk_elements::(&chunk, &v), + SampleBufferValue::I64(v) => array.store_chunk_elements::(&chunk, &v), + SampleBufferValue::Bool(v) => array.store_chunk_elements::(&chunk, &v), + } + } else { + let mut shape: Vec<_> = array.shape().iter().cloned().collect(); + assert!(shape.len() >= 2); + shape[0] = 1; + shape[1] = data.len as u64; + let chunk_subset = ArraySubset::new_with_shape(shape); + match data.values { + 
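            // Partially filled chunk: write only the first `data.len` draws via
            // a chunk-subset store; the remaining rows keep the fill value.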
SampleBufferValue::F64(v) => { + assert!(v.len() == chunk_subset.num_elements_usize()); + array.store_chunk_subset_elements(&chunk, &chunk_subset, &v) + } + SampleBufferValue::F32(v) => { + assert!(v.len() == chunk_subset.num_elements_usize()); + array.store_chunk_subset_elements(&chunk, &chunk_subset, &v) + } + SampleBufferValue::U64(v) => { + assert!(v.len() == chunk_subset.num_elements_usize()); + array.store_chunk_subset_elements(&chunk, &chunk_subset, &v) + } + SampleBufferValue::I64(v) => { + assert!(v.len() == chunk_subset.num_elements_usize()); + array.store_chunk_subset_elements(&chunk, &chunk_subset, &v) + } + SampleBufferValue::Bool(v) => { + assert!(v.len() == chunk_subset.num_elements_usize()); + array.store_chunk_subset_elements(&chunk, &chunk_subset, &v) + } + } + }; + + result.context(format!( + "Failed to store chunk for variable {} at chunk {} with length {}", + array.path(), + data.chunk_idx, + data.len + ))?; + Ok(()) +} + +impl ZarrChainStorage { + /// Create a new chain storage with buffers for parameters and samples + fn new( + arrays: Arc, + param_types: &Vec<(String, ItemType)>, + draw_types: &Vec<(String, ItemType)>, + buffer_size: u64, + chain: u64, + ) -> Self { + let draw_buffers = draw_types + .iter() + .map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size))) + .collect(); + + let stats_buffers = param_types + .iter() + .map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size))) + .collect(); + Self { + draw_buffers, + stats_buffers, + arrays, + chain, + last_sample_was_warmup: true, + } + } + + /// Store a parameter value, writing to Zarr when buffer is full + fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> { + if ["draw", "chain"].contains(&name) { + return Ok(()); + } + let Some(buffer) = self.stats_buffers.get_mut(name) else { + panic!("Unknown param name: {}", name); + }; + if let Some(chunk) = buffer.push(value) { + let array = if is_warmup { + &self.arrays.warmup_param_arrays[name] + } else { + &self.arrays.sample_param_arrays[name] + }; + store_zarr_chunk(array, chunk, self.chain)?; + } + Ok(()) + } + + /// Store a draw value, writing to Zarr when buffer is full + fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> { + if ["draw", "chain"].contains(&name) { + return Ok(()); + } + let Some(buffer) = self.draw_buffers.get_mut(name) else { + panic!("Unknown posterior variable name: {}", name); + }; + if let Some(chunk) = buffer.push(value) { + let array = if is_warmup { + &self.arrays.warmup_draw_arrays[name] + } else { + &self.arrays.sample_draw_arrays[name] + }; + store_zarr_chunk(array, chunk, self.chain)?; + } + Ok(()) + } +} + +impl ChainStorage for ZarrChainStorage { + type Finalized = (); + + fn record_sample( + &mut self, + _settings: &impl Settings, + stats: Vec<(&str, Option)>, + draws: Vec<(&str, Option)>, + info: &Progress, + ) -> Result<()> { + let is_first_draw = self.last_sample_was_warmup && !info.tuning; + if is_first_draw { + for (key, buffer) in self.draw_buffers.iter_mut() { + if let Some(chunk) = buffer.reset() { + store_zarr_chunk(&self.arrays.warmup_draw_arrays[key], chunk, self.chain)?; + } + } + for (key, buffer) in self.stats_buffers.iter_mut() { + if let Some(chunk) = buffer.reset() { + store_zarr_chunk(&self.arrays.warmup_param_arrays[key], chunk, self.chain)?; + } + } + self.last_sample_was_warmup = false; + } + + for (name, value) in stats { + if let Some(value) = value { + self.push_param(name, value, 
info.tuning)?; + } + } + for (name, value) in draws { + if let Some(value) = value { + self.push_draw(name, value, info.tuning)?; + } else { + panic!("Missing draw value for {}", name); + } + } + Ok(()) + } + + /// Flush remaining samples and finalize storage + fn finalize(self) -> Result { + for (key, mut buffer) in self.draw_buffers.into_iter() { + if let Some(chunk) = buffer.reset() { + let array = if self.last_sample_was_warmup { + &self.arrays.warmup_draw_arrays[&key] + } else { + &self.arrays.sample_draw_arrays[&key] + }; + store_zarr_chunk(array, chunk, self.chain)?; + } + } + for (key, mut buffer) in self.stats_buffers.into_iter() { + if let Some(chunk) = buffer.reset() { + let array = if self.last_sample_was_warmup { + &self.arrays.warmup_param_arrays[&key] + } else { + &self.arrays.sample_param_arrays[&key] + }; + store_zarr_chunk(array, chunk, self.chain)?; + } + } + Ok(()) + } + + /// Write current buffer contents to storage without modifying the buffers + fn flush(&self) -> Result<()> { + // Flush all draw buffers that have data + for (key, buffer) in &self.draw_buffers { + if let Some(temp_chunk) = buffer.copy_as_chunk() { + // Store the temporary chunk + let array = if self.last_sample_was_warmup { + &self.arrays.warmup_draw_arrays[key] + } else { + &self.arrays.sample_draw_arrays[key] + }; + store_zarr_chunk(array, temp_chunk, self.chain)?; + } + } + + // Flush all stats buffers that have data + for (key, buffer) in &self.stats_buffers { + if let Some(temp_chunk) = buffer.copy_as_chunk() { + // Store the temporary chunk + let array = if self.last_sample_was_warmup { + &self.arrays.warmup_param_arrays[key] + } else { + &self.arrays.sample_param_arrays[key] + }; + store_zarr_chunk(array, temp_chunk, self.chain)?; + } + } + + Ok(()) + } +} + +/// Configuration for Zarr-based MCMC storage. +/// +/// This is the main interface for configuring Zarr storage for MCMC sampling. +/// Zarr provides efficient, chunked storage for large datasets with good +/// compression and parallel I/O support. +/// +/// The storage organizes data into groups: +/// - `posterior/` - posterior samples +/// - `sample_stats/` - sampling statistics +/// - `warmup_posterior/` - warmup samples (optional) +/// - `warmup_sample_stats/` - warmup statistics (optional) +pub struct ZarrConfig { + store: ReadableWritableListableStorage, + group_path: Option, + draw_chunk_size: u64, + store_warmup: bool, +} + +impl ZarrConfig { + /// Create a new Zarr configuration with default settings. + /// + /// Default settings: + /// - `draw_chunk_size`: 100 samples per chunk + /// - `store_warmup`: true (warmup samples are stored) + /// - `group_path`: root of the store + pub fn new(store: ReadableWritableListableStorage) -> Self { + Self { + store, + group_path: None, + draw_chunk_size: 100, + store_warmup: true, + } + } + + /// Set the number of samples per chunk. + /// + /// Larger chunks use more memory but may provide better I/O performance. + /// Smaller chunks provide more frequent flushing and lower memory usage. + pub fn with_chunk_size(mut self, chunk_size: u64) -> Self { + self.draw_chunk_size = chunk_size; + self + } + + /// Set the group path within the Zarr store. + /// + /// If not set, data is stored at the root of the store. + pub fn with_group_path>(mut self, path: S) -> Self { + self.group_path = Some(path.into()); + self + } + + /// Configure whether to store warmup samples. + /// + /// When true, warmup samples are stored in separate groups. + /// When false, only post-warmup samples are stored. 
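    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an in-memory `zarrs` store as in the tests:
    ///
    /// ```no_run
    /// use std::sync::Arc;
    /// use nuts_rs::ZarrConfig;
    /// use zarrs::storage::store::MemoryStore;
    ///
    /// let store = Arc::new(MemoryStore::new());
    /// let config = ZarrConfig::new(store)
    ///     .with_chunk_size(200)
    ///     .store_warmup(false);
    /// ```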
+ pub fn store_warmup(mut self, store: bool) -> Self { + self.store_warmup = store; + self + } +} + +impl StorageConfig for ZarrConfig { + type Storage = ZarrTraceStorage; + + fn new_trace(self, settings: &impl Settings, math: &M) -> Result { + let n_chains = settings.num_chains() as u64; + let n_tune = settings.hint_num_tune() as u64; + let n_draws = settings.hint_num_draws() as u64; + + let param_types = settings.stat_types(math); + let draw_types = settings.data_types(math); + + let param_dims = settings.stat_dims_all(math); + let draw_dims = settings.data_dims_all(math); + + let draw_dim_sizes = math.dim_sizes(); + let stat_dim_sizes = settings.stat_dim_sizes(math); + + let mut group_path = self.group_path.unwrap_or_else(|| "".to_string()); + if !group_path.ends_with('/') { + group_path.push('/'); + } + let store = self.store; + let draw_chunk_size = self.draw_chunk_size; + + let mut root = GroupBuilder::new().build(store.clone(), &group_path)?; + + let attrs = root.attributes_mut(); + attrs.insert( + "sampler".to_string(), + serde_json::Value::String(env!("CARGO_PKG_NAME").to_string()), + ); + attrs.insert( + "sampler_version".to_string(), + serde_json::Value::String(env!("CARGO_PKG_VERSION").to_string()), + ); + attrs.insert( + "sampler_settings".to_string(), + serde_json::to_value(settings).context("Could not serialize sampler settings")?, + ); + root.store_metadata()?; + + GroupBuilder::new() + .build(store.clone(), &format!("{}warmup_posterior", group_path))? + .store_metadata()?; + GroupBuilder::new() + .build(store.clone(), &format!("{}warmup_sample_stats", group_path))? + .store_metadata()?; + GroupBuilder::new() + .build(store.clone(), &format!("{}posterior", group_path))? + .store_metadata()?; + GroupBuilder::new() + .build(store.clone(), &format!("{}sample_stats", group_path))? 
+ .store_metadata()?; + + let warmup_param_arrays = create_arrays( + store.clone(), + &format!("{}warmup_sample_stats", group_path), + ¶m_types, + ¶m_dims, + n_chains, + n_tune, + &stat_dim_sizes, + self.draw_chunk_size, + )?; + for array in warmup_param_arrays.values() { + array.store_metadata()?; + } + let sample_param_arrays = create_arrays( + store.clone(), + &format!("{}sample_stats", group_path), + ¶m_types, + ¶m_dims, + n_chains, + n_draws, + &stat_dim_sizes, + self.draw_chunk_size, + )?; + for array in sample_param_arrays.values() { + array.store_metadata()?; + } + let warmup_draw_arrays = create_arrays( + store.clone(), + &format!("{}warmup_posterior", group_path), + &draw_types, + &draw_dims, + n_chains, + n_tune, + &draw_dim_sizes, + self.draw_chunk_size, + )?; + for array in warmup_draw_arrays.values() { + array.store_metadata()?; + } + let sample_draw_arrays = create_arrays( + store.clone(), + &format!("{}posterior", group_path), + &draw_types, + &draw_dims, + n_chains, + n_draws, + &draw_dim_sizes, + self.draw_chunk_size, + )?; + for array in sample_draw_arrays.values() { + array.store_metadata()?; + } + let trace_storage = ArrayCollection { + warmup_param_arrays, + sample_param_arrays, + warmup_draw_arrays, + sample_draw_arrays, + }; + + let draw_coords = math.coords(); + let stat_coords = settings.stat_coords(math); + + store_coords( + store.clone(), + format!("{}posterior", &group_path), + &draw_coords, + )?; + store_coords( + store.clone(), + format!("{}warmup_posterior", &group_path), + &draw_coords, + )?; + store_coords( + store.clone(), + format!("{}sample_stats", &group_path), + &stat_coords, + )?; + store_coords( + store.clone(), + format!("{}warmup_sample_stats", &group_path), + &stat_coords, + )?; + + Ok(ZarrTraceStorage { + arrays: Arc::new(trace_storage), + param_types, + draw_types, + draw_chunk_size, + }) + } +} + +impl TraceStorage for ZarrTraceStorage { + type ChainStorage = ZarrChainStorage; + + type Finalized = (); + + fn initialize_trace_for_chain(&self, chain_id: u64) -> Result { + Ok(ZarrChainStorage::new( + self.arrays.clone(), + &self.param_types, + &self.draw_types, + self.draw_chunk_size, + chain_id as _, + )) + } + + fn finalize( + self, + traces: Vec::Finalized>>, + ) -> Result<(Option, Self::Finalized)> { + for trace in traces { + if let Err(err) = trace { + return Ok((Some(err), ())); + } + } + Ok((None, ())) + } +} diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs index b1fa2b0..6360ab1 100644 --- a/src/transform_adapt_strategy.rs +++ b/src/transform_adapt_strategy.rs @@ -1,22 +1,23 @@ -use arrow::array::StructArray; +use nuts_derive::Storable; +use serde::Serialize; use crate::adapt_strategy::CombinedCollector; use crate::chain::AdaptStrategy; use crate::hamiltonian::{Hamiltonian, Point}; use crate::nuts::{Collector, NutsOptions, SampleInfo}; -use crate::sampler_stats::{SamplerStats, StatTraceBuilder}; +use crate::sampler_stats::SamplerStats; use crate::state::State; use crate::stepsize::AcceptanceRateCollector; -use crate::stepsize_adapt::{StatsBuilder as StepSizeStatsBuilder, Strategy as StepSizeStrategy}; +use crate::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy}; use crate::transformed_hamiltonian::TransformedHamiltonian; -use crate::{DualAverageSettings, Math, NutsError, Settings}; +use crate::{Math, NutsError}; -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Serialize)] pub struct TransformedSettings { pub step_size_window: f64, pub transform_update_freq: u64, pub 
use_orbit_for_training: bool, - pub dual_average_options: DualAverageSettings, + pub step_size_settings: StepSizeSettings, pub transform_train_max_energy_error: f64, } @@ -27,7 +28,7 @@ impl Default for TransformedSettings { transform_update_freq: 128, use_orbit_for_training: false, transform_train_max_energy_error: 20f64, - dual_average_options: Default::default(), + step_size_settings: Default::default(), } } } @@ -41,39 +42,15 @@ pub struct TransformAdaptation { chain: u64, } -pub struct Builder { - step_size: StepSizeStatsBuilder, -} - -impl StatTraceBuilder for Builder { - fn append_value(&mut self, math: Option<&mut M>, value: &TransformAdaptation) { - let Self { step_size } = self; - step_size.append_value(math, &value.step_size); - } - - fn finalize(self) -> Option { - let Self { step_size } = self; - >::finalize(step_size) - } - - fn inspect(&self) -> Option { - let Self { step_size } = self; - >::inspect(step_size) - } -} +#[derive(Debug, Storable)] +pub struct Stats {} impl SamplerStats for TransformAdaptation { - type Builder = Builder; - type StatOptions = (); + type Stats = Stats; + type StatsOptions = (); - fn new_builder( - &self, - _stat_options: Self::StatOptions, - settings: &impl Settings, - dim: usize, - ) -> Self::Builder { - let step_size = SamplerStats::::new_builder(&self.step_size, (), settings, dim); - Builder { step_size } + fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { + Stats {} } } @@ -172,7 +149,7 @@ impl AdaptStrategy for TransformAdaptation { type Options = TransformedSettings; fn new(_math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self { - let step_size = StepSizeStrategy::new(options.dual_average_options); + let step_size = StepSizeStrategy::new(options.step_size_settings); let final_window_size = ((num_tune as f64) * (1f64 - options.step_size_window)).floor() as u64; Self { @@ -212,13 +189,15 @@ impl AdaptStrategy for TransformAdaptation { self.step_size.update(&collector.collector1); if draw >= self.num_tune { + // Needed for step size jitter + self.step_size.update_stepsize(rng, hamiltonian, true); self.tuning = false; return Ok(()); } if draw < self.final_window_size { if draw < 100 { - if (draw > 0) & (draw % 10 == 0) { + if (draw > 0) & draw.is_multiple_of(10) { hamiltonian.update_params( math, rng, @@ -227,7 +206,7 @@ impl AdaptStrategy for TransformAdaptation { collector.collector2.logps.iter(), )?; } - } else if (draw > 0) & (draw % self.options.transform_update_freq == 0) { + } else if (draw > 0) & draw.is_multiple_of(self.options.transform_update_freq) { hamiltonian.update_params( math, rng, @@ -237,13 +216,13 @@ impl AdaptStrategy for TransformAdaptation { )?; } self.step_size.update_estimator_early(); - self.step_size.update_stepsize(hamiltonian, false); + self.step_size.update_stepsize(rng, hamiltonian, false); return Ok(()); } self.step_size.update_estimator_late(); let is_last = draw == self.num_tune - 1; - self.step_size.update_stepsize(hamiltonian, is_last); + self.step_size.update_stepsize(rng, hamiltonian, is_last); Ok(()) } diff --git a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs index 93581a7..7b97482 100644 --- a/src/transformed_hamiltonian.rs +++ b/src/transformed_hamiltonian.rs @@ -1,18 +1,12 @@ use std::{marker::PhantomData, sync::Arc}; -use arrow::{ - array::{ - ArrayBuilder, FixedSizeListBuilder, Float64Builder, Int64Builder, PrimitiveBuilder, - StructArray, - }, - datatypes::{DataType, Field, Float64Type, Int64Type}, -}; +use nuts_derive::Storable; 
use crate::{ + DivergenceInfo, LogpError, Math, NutsError, hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point}, - sampler_stats::{SamplerStats, StatTraceBuilder}, + sampler_stats::SamplerStats, state::{State, StatePool}, - DivergenceInfo, LogpError, Math, NutsError, Settings, }; pub struct TransformedPoint { @@ -29,121 +23,12 @@ pub struct TransformedPoint { transform_id: i64, } -pub struct TransformedPointStatsBuilder { - fisher_distance: Float64Builder, - transformed_position: Option>>, - transformed_gradient: Option>>, - transformation_index: PrimitiveBuilder, -} - -impl StatTraceBuilder> for TransformedPointStatsBuilder { - fn append_value(&mut self, math: Option<&mut M>, value: &TransformedPoint) { - let math = math.expect("Transformed point stats need math instance"); - let Self { - fisher_distance, - transformed_position, - transformed_gradient, - transformation_index, - } = self; - - fisher_distance.append_value( - math.sq_norm_sum(&value.transformed_position, &value.transformed_gradient), - ); - transformation_index.append_value(value.transform_id); - - if let Some(store) = transformed_position { - store - .values() - .append_slice(&math.box_array(&value.transformed_position)); - store.append(true); - } - if let Some(store) = transformed_gradient { - store - .values() - .append_slice(&math.box_array(&value.transformed_gradient)); - store.append(true); - } - } - - fn finalize(self) -> Option { - let Self { - mut fisher_distance, - transformed_position, - transformed_gradient, - mut transformation_index, - } = self; - - let mut fields = vec![ - Field::new("fisher_distance", DataType::Float64, false), - Field::new("transformation_index", DataType::Int64, false), - ]; - let mut arrays = vec![ - ArrayBuilder::finish(&mut fisher_distance), - ArrayBuilder::finish(&mut transformation_index), - ]; - - if let Some(mut transformed_position) = transformed_position { - let array = ArrayBuilder::finish(&mut transformed_position); - fields.push(Field::new( - "transformed_position", - array.data_type().clone(), - true, - )); - arrays.push(array); - } - - if let Some(mut transformed_gradient) = transformed_gradient { - let array = ArrayBuilder::finish(&mut transformed_gradient); - fields.push(Field::new( - "transformed_gradient", - array.data_type().clone(), - true, - )); - arrays.push(array); - } - - Some(StructArray::new(fields.into(), arrays, None)) - } - - fn inspect(&self) -> Option { - let Self { - fisher_distance, - transformed_position, - transformed_gradient, - transformation_index, - } = self; - - let mut fields = vec![ - Field::new("fisher_distance", DataType::Float64, false), - Field::new("transformation_index", DataType::Int64, false), - ]; - let mut arrays = vec![ - ArrayBuilder::finish_cloned(fisher_distance), - ArrayBuilder::finish_cloned(transformation_index), - ]; - - if let Some(transformed_position) = transformed_position { - let array = ArrayBuilder::finish_cloned(transformed_position); - fields.push(Field::new( - "transformed_position", - array.data_type().clone(), - true, - )); - arrays.push(array); - } - - if let Some(transformed_gradient) = transformed_gradient { - let array = ArrayBuilder::finish_cloned(transformed_gradient); - fields.push(Field::new( - "transformed_gradient", - array.data_type().clone(), - true, - )); - arrays.push(array); - } - - Some(StructArray::new(fields.into(), arrays, None)) - } +#[derive(Debug, Storable)] +pub struct PointStats { + pub fisher_distance: f64, + pub transformed_position: Option>, + pub transformed_gradient: Option>, + pub 
transformation_index: i64, } #[derive(Debug, Clone, Copy)] @@ -152,30 +37,23 @@ pub struct TransformedPointStatsOptions { } impl SamplerStats for TransformedPoint { - type Builder = TransformedPointStatsBuilder; - type StatOptions = TransformedPointStatsOptions; - - fn new_builder( - &self, - stat_options: Self::StatOptions, - settings: &impl Settings, - dim: usize, - ) -> Self::Builder { - let count = settings.hint_num_tune() + settings.hint_num_draws(); + type Stats = PointStats; + type StatsOptions = TransformedPointStatsOptions; + fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats { let mut transformed_position = None; let mut transformed_gradient = None; - if stat_options.store_transformed { - let items = PrimitiveBuilder::new(); - transformed_position = Some(FixedSizeListBuilder::new(items, dim as _)); - let items = PrimitiveBuilder::new(); - transformed_gradient = Some(FixedSizeListBuilder::new(items, dim as _)); + if opt.store_transformed { + transformed_position = Some(math.box_array(&self.transformed_position)); + transformed_gradient = Some(math.box_array(&self.transformed_gradient)); } - TransformedPointStatsBuilder { - fisher_distance: Float64Builder::with_capacity(count), - transformation_index: Int64Builder::with_capacity(count), - transformed_gradient, - transformed_position, + let fisher_distance = + math.sq_norm_sum(&self.transformed_position, &self.transformed_gradient); + PointStats { + fisher_distance, + transformation_index: self.transform_id, + transformed_gradient: transformed_gradient.map(|x| x.into_vec()), + transformed_position: transformed_position.map(|x| x.into_vec()), } } } @@ -337,7 +215,7 @@ pub struct TransformedHamiltonian { ones: M::Vector, zeros: M::Vector, step_size: f64, - params: Option, + params: Option, max_energy_error: f64, _phantom: PhantomData, pool: StatePool>, @@ -401,49 +279,18 @@ impl TransformedHamiltonian { } } -pub struct Builder { - step_size: Float64Builder, -} - -impl StatTraceBuilder> for Builder { - fn append_value(&mut self, _math: Option<&mut M>, value: &TransformedHamiltonian) { - let Self { step_size } = self; - step_size.append_value(value.step_size); - } - - fn finalize(self) -> Option { - let Self { mut step_size } = self; - - let fields = vec![Field::new("step_size", DataType::Float64, false)]; - let arrays = vec![ArrayBuilder::finish(&mut step_size)]; - - Some(StructArray::new(fields.into(), arrays, None)) - } - - fn inspect(&self) -> Option { - let Self { step_size } = self; - - let fields = vec![Field::new("step_size", DataType::Float64, false)]; - let arrays = vec![ArrayBuilder::finish_cloned(step_size)]; - - Some(StructArray::new(fields.into(), arrays, None)) - } +#[derive(Debug, Storable)] +pub struct HamiltonianStats { + pub step_size: f64, } impl SamplerStats for TransformedHamiltonian { - type Builder = Builder; - type StatOptions = (); + type Stats = HamiltonianStats; + type StatsOptions = (); - fn new_builder( - &self, - _stat_options: Self::StatOptions, - settings: &impl Settings, - _dim: usize, - ) -> Self::Builder { - Builder { - step_size: Float64Builder::with_capacity( - settings.hint_num_draws() + settings.hint_num_tune(), - ), + fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { + HamiltonianStats { + step_size: self.step_size, } } } diff --git a/tests/sample_normal.rs b/tests/sample_normal.rs index a6288bc..441b639 100644 --- a/tests/sample_normal.rs +++ b/tests/sample_normal.rs @@ -3,18 +3,20 @@ use std::{ time::{Duration, Instant}, }; -use 
arrow::{ - array::{Array, ArrayBuilder, FixedSizeListBuilder, PrimitiveBuilder}, - datatypes::Float64Type, -}; +use anyhow::Context; use nuts_rs::{ - CpuLogpFunc, CpuMath, DiagAdaptExpSettings, DiagGradNutsSettings, DrawStorage, - EuclideanAdaptOptions, LogpError, LowRankNutsSettings, Model, Sampler, SamplerWaitResult, - Settings, Trace, + CpuLogpFunc, CpuMath, DiagAdaptExpSettings, DiagGradNutsSettings, EuclideanAdaptOptions, + LogpError, LowRankNutsSettings, Model, Sampler, SamplerWaitResult, ZarrConfig, }; +use nuts_storable::HasDims; use rand::prelude::Rng; use rand_distr::{Distribution, StandardNormal}; use thiserror::Error; +use zarrs::{ + array::Array, + array_subset::ArraySubset, + storage::{ReadableListableStorageTraits, store::MemoryStore}, +}; struct NormalLogp<'a> { dim: usize, @@ -30,9 +32,19 @@ impl LogpError for NormalLogpError { } } +impl HasDims for NormalLogp<'_> { + fn dim_sizes(&self) -> std::collections::HashMap { + std::collections::HashMap::from([ + ("unconstrained_parameter".to_string(), self.dim as u64), + ("dim".to_string(), self.dim as u64), + ]) + } +} + impl<'a> CpuLogpFunc for NormalLogp<'a> { type LogpError = NormalLogpError; - type TransformParams = (); + type FlowParameters = (); + type ExpandedVector = Vec; fn dim(&self) -> usize { self.dim @@ -55,90 +67,15 @@ impl<'a> CpuLogpFunc for NormalLogp<'a> { Ok(logp) } - fn inv_transform_normalize( - &mut self, - _params: &Self::TransformParams, - _untransformed_position: &[f64], - _untransofrmed_gradient: &[f64], - _transformed_position: &mut [f64], - _transformed_gradient: &mut [f64], - ) -> Result { - todo!() - } - - fn init_from_transformed_position( - &mut self, - _params: &Self::TransformParams, - _untransformed_position: &mut [f64], - _untransformed_gradient: &mut [f64], - _transformed_position: &[f64], - _transformed_gradient: &mut [f64], - ) -> Result<(f64, f64), Self::LogpError> { - todo!() - } - - fn init_from_untransformed_position( - &mut self, - _params: &Self::TransformParams, - _untransformed_position: &[f64], - _untransformed_gradient: &mut [f64], - _transformed_position: &mut [f64], - _transformed_gradient: &mut [f64], - ) -> Result<(f64, f64), Self::LogpError> { - todo!() - } - - fn update_transformation<'b, R: rand::Rng + ?Sized>( - &'b mut self, - _rng: &mut R, - _untransformed_positions: impl Iterator, - _untransformed_gradients: impl Iterator, - _logps: impl Iterator, - _params: &'b mut Self::TransformParams, - ) -> Result<(), Self::LogpError> { - todo!() - } - - fn new_transformation( + fn expand_vector( &mut self, _rng: &mut R, - _untransformed_position: &[f64], - _untransfogmed_gradient: &[f64], - _chain: u64, - ) -> Result { - todo!() - } - - fn transformation_id(&self, _params: &Self::TransformParams) -> Result { - todo!() - } -} - -struct Storage { - draws: FixedSizeListBuilder>, -} - -impl Storage { - fn new(size: usize) -> Storage { - let values = PrimitiveBuilder::new(); - let draws = FixedSizeListBuilder::new(values, size as i32); - Storage { draws } - } -} - -impl DrawStorage for Storage { - fn append_value(&mut self, point: &[f64]) -> anyhow::Result<()> { - self.draws.values().append_slice(point); - self.draws.append(true); - Ok(()) - } - - fn finalize(mut self) -> anyhow::Result> { - Ok(ArrayBuilder::finish(&mut self.draws)) - } - - fn inspect(&self) -> anyhow::Result> { - Ok(ArrayBuilder::finish_cloned(&self.draws)) + array: &[f64], + ) -> Result + where + R: rand::Rng + ?Sized, + { + Ok(array.to_vec()) } } @@ -158,21 +95,7 @@ impl Model for NormalModel { where Self: 'model; - 
-        = Storage
-    where
-        Self: 'model;
-
-    fn new_trace<'model, S: Settings, R: Rng + ?Sized>(
-        &'model self,
-        _rng: &mut R,
-        _chain_id: u64,
-        _settings: &'model S,
-    ) -> anyhow::Result<Self::DrawStorage<'model, S>> {
-        Ok(Storage::new(self.mu.len()))
-    }
-
-    fn math(&self) -> anyhow::Result<Self::Math<'_>> {
+    fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> anyhow::Result<Self::Math<'_>> {
         Ok(CpuMath::new(NormalLogp {
             dim: self.mu.len(),
             mu: &self.mu,
@@ -190,7 +113,7 @@ impl Model for NormalModel {
     }
 }
 
-fn sample() -> anyhow::Result<Trace> {
+fn sample() -> anyhow::Result<Arc<MemoryStore>> {
     let mu = vec![0.5; 100];
     let model = NormalModel::new(mu.into());
     let settings = DiagGradNutsSettings {
@@ -199,19 +122,21 @@
         ..Default::default()
     };
 
-    let mut sampler = Sampler::new(model, settings, 6, None)?;
+    let store = Arc::new(MemoryStore::new());
+    let trace_config = ZarrConfig::new(store.clone());
+    let mut sampler = Sampler::new(model, settings, trace_config, 6, None)?;
 
-    let trace = loop {
+    let _ = loop {
        match sampler.wait_timeout(Duration::from_secs(1)) {
             SamplerWaitResult::Trace(trace) => break trace,
             SamplerWaitResult::Timeout(new_sampler) => sampler = new_sampler,
             SamplerWaitResult::Err(err, _trace) => return Err(err),
         };
     };
-    Ok(trace)
+    Ok(store)
 }
 
-fn sample_debug_stats() -> anyhow::Result<Trace> {
+fn sample_debug_stats() -> anyhow::Result<Arc<MemoryStore>> {
     let mu = vec![0.5; 100];
     let model = NormalModel::new(mu.into());
     let settings = DiagGradNutsSettings {
@@ -231,19 +156,43 @@
         ..Default::default()
     };
 
-    let mut sampler = Sampler::new(model, settings, 6, None)?;
+    let store = Arc::new(MemoryStore::new());
+    let trace_config = ZarrConfig::new(store.clone());
+    let mut sampler = Sampler::new(model, settings, trace_config, 6, None)?;
 
-    let trace = loop {
+    let _ = loop {
         match sampler.wait_timeout(Duration::from_secs(1)) {
             SamplerWaitResult::Trace(trace) => break trace,
             SamplerWaitResult::Timeout(new_sampler) => sampler = new_sampler,
             SamplerWaitResult::Err(err, _trace) => return Err(err),
         };
     };
-    Ok(trace)
+
+    let store_dyn: Arc<dyn ReadableListableStorageTraits> = store.clone();
+
+    let diverging = Array::open(store_dyn.clone(), "/sample_stats/diverging")
+        .context("Could not read diverging array")?;
+    assert!(
+        diverging.dimension_names().as_ref().unwrap()
+            == &[Some("chain".to_string()), Some("draw".to_string())]
+    );
+
+    let _: Vec<bool> = diverging
+        .retrieve_array_subset_elements(&ArraySubset::new_with_shape(diverging.shape().to_vec()))?;
+
+    let logp = Array::open(store_dyn.clone(), "/sample_stats/logp")
+        .context("Could not read logp array")?;
+    assert!(
+        logp.dimension_names().as_ref().unwrap()
+            == &[Some("chain".to_string()), Some("draw".to_string())]
    );
+    let _: Vec<f64> =
+        logp.retrieve_array_subset_elements(&ArraySubset::new_with_shape(logp.shape().to_vec()))?;
+
+    Ok(store)
 }
 
-fn sample_eigs_debug_stats() -> anyhow::Result<Trace> {
+fn sample_eigs_debug_stats() -> anyhow::Result<Arc<MemoryStore>> {
     let mu = vec![0.5; 10];
     let model = NormalModel::new(mu.into());
     let settings = LowRankNutsSettings {
@@ -264,9 +213,11 @@
         ..Default::default()
     };
 
-    let mut sampler = Sampler::new(model, settings, 6, None)?;
+    let store = Arc::new(MemoryStore::new());
+    let trace_config = ZarrConfig::new(store.clone());
+    let mut sampler = Sampler::new(model, settings, trace_config, 1, None)?;
 
-    let trace = loop {
+    let _trace = loop {
         match sampler.wait_timeout(Duration::from_secs(1)) {
             SamplerWaitResult::Trace(trace) => break trace,
             SamplerWaitResult::Timeout(new_sampler) => sampler = new_sampler,
@@ -274,14 +225,13 @@ fn sample_eigs_debug_stats() -> anyhow::Result<Trace> {
         };
     };
 
-    Ok(trace)
+    Ok(store)
 }
 
 #[test]
 fn run() -> anyhow::Result<()> {
     let start = Instant::now();
-    let trace = sample()?;
-    assert!(trace.chains.len() == 6);
+    let _ = sample()?;
     dbg!(start.elapsed());
     Ok(())
 }
@@ -289,8 +239,7 @@ fn run() -> anyhow::Result<()> {
 #[test]
 fn run_debug_stats() -> anyhow::Result<()> {
     let start = Instant::now();
-    let trace = sample_debug_stats()?;
-    assert!(trace.chains.len() == 6);
+    let _ = sample_debug_stats()?;
     dbg!(start.elapsed());
     Ok(())
 }
@@ -298,8 +247,7 @@ fn run_debug_stats() -> anyhow::Result<()> {
 #[test]
 fn run_debug_stats_eigs() -> anyhow::Result<()> {
     let start = Instant::now();
-    let trace = sample_eigs_debug_stats()?;
-    assert!(trace.chains.len() == 1);
+    let _ = sample_eigs_debug_stats()?;
     dbg!(start.elapsed());
     Ok(())
 }
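
Note on reading traces back: the in-memory assertions above translate directly to a persistent store. The sketch below is illustrative and not part of this change; the read_logp helper, the trace_path argument, and the use of FilesystemStore (zarrs' "filesystem" feature) are assumptions, while the "/sample_stats/logp" path and the ("chain", "draw") dimension layout are exactly what sample_debug_stats asserts.

use std::sync::Arc;

use anyhow::Context;
use zarrs::{
    array::Array,
    array_subset::ArraySubset,
    filesystem::FilesystemStore, // assumes zarrs is built with the "filesystem" feature
    storage::ReadableListableStorageTraits,
};

// Hypothetical helper: read the logp stats from a trace that a previous run
// persisted to disk via ZarrConfig over a FilesystemStore rooted at trace_path.
fn read_logp(trace_path: &str) -> anyhow::Result<Vec<f64>> {
    let store: Arc<dyn ReadableListableStorageTraits> =
        Arc::new(FilesystemStore::new(trace_path)?);
    // Same layout the tests assert: per-draw stats live under /sample_stats
    // with dimensions (chain, draw).
    let logp = Array::open(store, "/sample_stats/logp")
        .context("Could not read logp array")?;
    // Retrieve the whole (chain, draw) array as a flat, row-major Vec<f64>.
    let subset = ArraySubset::new_with_shape(logp.shape().to_vec());
    Ok(logp.retrieve_array_subset_elements::<f64>(&subset)?)
}

The same calls work against the MemoryStore used in these tests; only the store construction changes.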