diff --git a/Cargo.toml b/Cargo.toml index 85769ef..63844eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,16 +24,20 @@ itertools = "0.14.0" thiserror = "2.0.3" rand_chacha = "0.9.0" anyhow = "1.0.72" -faer = { version = "0.22.6", default-features = false, features = ["linalg"] } +faer = { version = "0.23.2", default-features = false, features = ["linalg"] } pulp = "0.21.4" rayon = "1.10.0" -zarrs = { version = "0.21.0", features = [ +zarrs = { version = "0.22.0", features = [ "filesystem", "gzip", "sharding", "async", ], optional = true } ndarray = { version = "0.16.1", optional = true } +arrow = { version = "56.2.0", optional = true } +arrow-schema = { version = "56.2.0", features = [ + "canonical_extension_types", +], optional = true } nuts-derive = { path = "./nuts-derive" } nuts-storable = { path = "./nuts-storable" } serde = { version = "1.0.219", features = ["derive"] } @@ -50,13 +54,14 @@ equator = "0.4.2" serde_json = "1.0" ndarray = "0.16.1" tempfile = "3.0" -zarrs_object_store = "0.4.3" +zarrs_object_store = "0.5.0" object_store = "0.12.0" tokio = { version = "1.0", features = ["rt", "rt-multi-thread"] } [features] zarr = ["dep:zarrs", "dep:tokio"] ndarray = ["dep:ndarray"] +arrow = ["dep:arrow", "dep:arrow-schema"] [[bench]] name = "sample" diff --git a/examples/arrow_trace.rs b/examples/arrow_trace.rs new file mode 100644 index 0000000..5d92999 --- /dev/null +++ b/examples/arrow_trace.rs @@ -0,0 +1,343 @@ +//! Arrow backend example for MCMC trace storage +//! +//! This example demonstrates how to use the nuts-rs library with Arrow storage +//! for running MCMC sampling on a multivariate normal distribution. It shows: +//! +//! - Setting up a custom probability model +//! - Configuring Arrow storage for results +//! - Running multiple parallel chains +//! - Monitoring progress during sampling +//! - Accessing results in Arrow/Parquet format + +use std::{ + collections::HashMap, + f64, + time::{Duration, Instant}, +}; + +use anyhow::Result; +use nuts_rs::{ + ArrowConfig, CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, + Sampler, SamplerWaitResult, Storable, +}; +use nuts_storable::{HasDims, Value}; +use rand::Rng; +use thiserror::Error; + +/// A multivariate normal distribution model +/// +/// This represents a probability distribution with mean μ and precision matrix P, +/// where the log probability is: logp(x) = -0.5 * (x - μ)^T * P * (x - μ) +#[derive(Clone, Debug)] +struct MultivariateNormal { + mean: Vec, + precision: Vec>, // Inverse of covariance matrix +} + +impl MultivariateNormal { + fn new(mean: Vec, precision: Vec>) -> Self { + Self { mean, precision } + } +} + +/// Custom error type for log probability calculations +/// +/// MCMC samplers need to distinguish between recoverable errors (like numerical +/// issues that can be handled by rejecting the proposal) and non-recoverable +/// errors (like programming bugs that should stop sampling). +#[allow(dead_code)] +#[derive(Debug, Error)] +enum MyLogpError { + #[error("Recoverable error in logp calculation: {0}")] + Recoverable(String), + #[error("Non-recoverable error in logp calculation: {0}")] + NonRecoverable(String), +} + +impl LogpError for MyLogpError { + fn is_recoverable(&self) -> bool { + matches!(self, MyLogpError::Recoverable(_)) + } +} + +/// Implementation of the log probability function for multivariate normal +/// +/// This struct contains the model parameters and implements the mathematical +/// operations needed for MCMC sampling: computing log probability and gradients. +#[derive(Clone)] +struct MvnLogp { + model: MultivariateNormal, + buffer: Vec, // Temporary buffer for computations +} + +impl HasDims for MvnLogp { + /// Define dimension names and sizes for storage + /// + /// This tells the storage system what array dimensions to expect. + /// These dimensions will be used to structure the output data using + /// Arrow's FixedShapeTensor extension type. + fn dim_sizes(&self) -> HashMap { + HashMap::from([ + // Dimension for the actual parameter vector x + ("x".to_string(), self.model.mean.len() as u64), + ]) + } + + fn coords(&self) -> HashMap { + HashMap::from([( + "x".to_string(), + Value::Strings(vec!["x1".to_string(), "x2".to_string()]), + )]) + } +} + +/// Additional quantities computed from each sample +/// +/// The `Storable` derive macro automatically generates code to store this +/// struct in the trace. The `dims` attribute specifies which dimension +/// each field should use. Multi-dimensional fields will be stored as +/// FixedShapeTensor extension types in Arrow format. +#[derive(Storable)] +struct ExpandedDraw { + /// Store the parameter values with dimension "x" + #[storable(dims("x"))] + prec: Vec, + /// A scalar derived quantity (difference between first two parameters) + diff: f64, +} + +impl CpuLogpFunc for MvnLogp { + type LogpError = MyLogpError; + type FlowParameters = (); // No parameter transformations needed + type ExpandedVector = ExpandedDraw; + + /// Return the dimensionality of the parameter space + fn dim(&self) -> usize { + self.model.mean.len() + } + + /// Compute log probability and gradient + /// + /// This is the core mathematical function that MCMC uses to explore + /// the parameter space. It computes both the log probability density + /// and its gradient for efficient sampling with Hamiltonian Monte Carlo. + fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result { + let n = x.len(); + + // Compute (x - mean) + let diff = &mut self.buffer; + for i in 0..n { + diff[i] = x[i] - self.model.mean[i]; + } + + let mut quad = 0.0; + + // Compute quadratic form and gradient: logp = -0.5 * diff^T * P * diff + for i in 0..n { + // Compute i-th component of P * diff + let mut pdot = 0.0; + for j in 0..n { + let pij = self.model.precision[i][j]; + pdot += pij * diff[j]; + quad += diff[i] * pij * diff[j]; + } + // Gradient of logp w.r.t. x_i: derivative of -0.5 * diff^T P diff is -(P * diff)_i + grad[i] = -pdot; + } + + Ok(-0.5 * quad) + } + + /// Compute additional quantities from each sample + /// + /// This function is called for each accepted sample to compute derived + /// quantities that should be stored in the trace. These might be + /// transformed parameters, predictions, or other quantities of interest. + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> Result { + // Store the raw parameter values and compute a simple derived quantity + Ok(ExpandedDraw { + prec: array.to_vec(), + diff: array[1] - array[0], // Example: difference between first two parameters + }) + } + + fn vector_coord(&self) -> Option { + Some(Value::Strings(vec!["x1".to_string(), "x2".to_string()])) + } +} + +/// The complete MCMC model +/// +/// This struct implements the Model trait, which is the main interface +/// that samplers use. It provides access to the mathematical operations +/// and handles initialization of the sampling chains. +struct MvnModel { + math: CpuMath, +} + +impl Model for MvnModel { + type Math<'model> + = CpuMath + where + Self: 'model; + + fn math(&self, _rng: &mut R) -> Result> { + Ok(self.math.clone()) + } + + /// Generate random initial positions for the chain + /// + /// Good initialization is important for MCMC efficiency. The starting + /// points should be in a reasonable region of the parameter space + /// where the log probability is finite. + fn init_position(&self, rng: &mut R, position: &mut [f64]) -> Result<()> { + // Initialize each parameter randomly in the range [-2, 2] + // For this simple example, this should put us in a reasonable + // region around the mode of the distribution + for p in position.iter_mut() { + *p = rng.random_range(-2.0..2.0); + } + Ok(()) + } +} + +fn main() -> Result<()> { + println!("=== Multivariate Normal MCMC with Arrow Storage ===\n"); + + // Create a 2D multivariate normal distribution + // This creates a distribution with mean [0, 0] and precision matrix + // [[1.0, 0.5], [0.5, 1.0]], which corresponds to some correlation + // between the two parameters + let mean = vec![0.0, 0.0]; + let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]]; + let mvn = MultivariateNormal::new(mean, precision); + + println!("Model: 2D Multivariate Normal"); + println!("Mean: {:?}", mvn.mean); + println!("Precision matrix: {:?}\n", mvn.precision); + + // Configure output location + let output_path = "mcmc_output"; + println!("Output will be saved to: {}/\n", output_path); + + // Sampling configuration + let num_chains = 4; // Run 4 parallel chains for better convergence assessment + let num_tune = 500; // Warmup samples to tune the sampler + let num_draws = 500; // Post-warmup samples to keep + + println!("Sampling configuration:"); + println!(" Chains: {}", num_chains); + println!(" Warmup samples: {}", num_tune); + println!(" Sampling draws: {}", num_draws); + + // Configure MCMC settings + // DiagGradNutsSettings provides sensible defaults for the NUTS sampler + let mut settings = DiagGradNutsSettings::default(); + settings.num_chains = num_chains as _; + settings.num_tune = num_tune; + settings.num_draws = num_draws as _; + settings.seed = 54; // For reproducible results + + // Create the model instance + let model = MvnModel { + math: CpuMath::new(MvnLogp { + model: mvn, + buffer: vec![0.0; 2], + }), + }; + + // Start sampling + println!("\nStarting MCMC sampling...\n"); + let start = Instant::now(); + + // Configure Arrow storage - it automatically determines capacity from settings + let arrow_config = ArrowConfig::new(); + + // Create sampler with 4 worker threads + // The sampler runs asynchronously, so we can monitor progress + let mut sampler = Some(Sampler::new(model, settings, arrow_config, 4, None)?); + + let mut num_progress_updates = 0; + + // Main sampling loop with progress monitoring + // This demonstrates how to monitor long-running sampling jobs + while let Some(sampler_) = sampler.take() { + match sampler_.wait_timeout(Duration::from_millis(50)) { + // Sampling completed successfully + SamplerWaitResult::Trace(traces) => { + println!("✓ Sampling completed in {:?}", start.elapsed()); + + // Display information about the resulting Arrow data + println!("✓ MCMC traces stored in Arrow format"); + println!("\nTrace summary:"); + println!(" Number of chains: {}", traces.len()); + + if let Some(first_trace) = traces.first() { + println!( + " Posterior samples: {} rows, {} columns", + first_trace.posterior.num_rows(), + first_trace.posterior.num_columns() + ); + println!( + " Sample stats: {} rows, {} columns", + first_trace.sample_stats.num_rows(), + first_trace.sample_stats.num_columns() + ); + + // Show column names + println!("\n Posterior columns:"); + for field in first_trace.posterior.schema().fields() { + println!( + " {} ({} {:?})", + field.name(), + field.data_type(), + field.metadata(), + ); + } + + println!("\n Sample stats columns:"); + for field in first_trace.sample_stats.schema().fields() { + println!( + " {} ({} {:?})", + field.name(), + field.data_type(), + field.metadata(), + ); + } + } + break; + } + + // Timeout - sampler is still running, show progress + SamplerWaitResult::Timeout(mut sampler_) => { + num_progress_updates += 1; + println!("Progress update {}:", num_progress_updates); + + // Get current progress from all chains + let progress = sampler_.progress()?; + for (i, chain) in progress.iter().enumerate() { + let phase = if chain.tuning { "warmup" } else { "sampling" }; + println!( + " Chain {}: {} samples ({} divergences), step size: {:.6} [{}]", + i, chain.finished_draws, chain.divergences, chain.step_size, phase + ); + } + println!(); // Add blank line for readability + + sampler = Some(sampler_); + } + + // An error occurred during sampling + SamplerWaitResult::Err(err, _) => { + eprintln!("✗ Sampling failed: {}", err); + return Err(err); + } + } + } + + Ok(()) +} diff --git a/nuts-derive/src/lib.rs b/nuts-derive/src/lib.rs index 5cd34e5..6452b92 100644 --- a/nuts-derive/src/lib.rs +++ b/nuts-derive/src/lib.rs @@ -79,17 +79,17 @@ enum StorableField { // Check if a type is a generic type parameter fn is_generic_param(ty: &Type, generics: &syn::Generics) -> bool { - if let Type::Path(type_path) = ty { - if type_path.path.segments.len() == 1 { - let type_name = &type_path.path.segments.first().unwrap().ident; - return generics.params.iter().any(|param| { - if let GenericParam::Type(type_param) = param { - &type_param.ident == type_name - } else { - false - } - }); - } + if let Type::Path(type_path) = ty + && type_path.path.segments.len() == 1 + { + let type_name = &type_path.path.segments.first().unwrap().ident; + return generics.params.iter().any(|param| { + if let GenericParam::Type(type_param) = param { + &type_param.ident == type_name + } else { + false + } + }); } false } @@ -97,16 +97,16 @@ fn is_generic_param(ty: &Type, generics: &syn::Generics) -> bool { // Check if a type implements Storable trait based on bounds fn has_storable_bound(ty: &Ident, generics: &syn::Generics) -> bool { for param in &generics.params { - if let GenericParam::Type(type_param) = param { - if &type_param.ident == ty { - for bound in &type_param.bounds { - if let syn::TypeParamBound::Trait(trait_bound) = bound { - let path = &trait_bound.path; - if path.segments.len() == 1 - && path.segments.first().unwrap().ident == "Storable" - { - return true; - } + if let GenericParam::Type(type_param) = param + && &type_param.ident == ty + { + for bound in &type_param.bounds { + if let syn::TypeParamBound::Trait(trait_bound) = bound { + let path = &trait_bound.path; + if path.segments.len() == 1 + && path.segments.first().unwrap().ident == "Storable" + { + return true; } } } @@ -165,7 +165,7 @@ pub fn storable_derive(input: TokenStream) -> TokenStream { ty_str ); }; - let item = if path.segments.first().unwrap().ident.to_string() == "Option" { + let item = if path.segments.first().unwrap().ident == "Option" { if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) = &path.segments.first().unwrap().arguments @@ -514,7 +514,7 @@ pub fn storable_derive(input: TokenStream) -> TokenStream { let name = &field.name; if field.is_option { quote! { - if let Some(inner) = &self.#name { + if let Some(inner) = &mut self.#name { result.extend(inner.get_all(parent)); } } @@ -526,7 +526,7 @@ pub fn storable_derive(input: TokenStream) -> TokenStream { let name = &field.name; if field.is_option { quote! { - if let Some(inner) = &self.#name { + if let Some(inner) = &mut self.#name { result.push((#name.to_string().as_str(), Some(nuts_storable::Value::Generic(Box::new(inner.clone()))))); } else { result.push((#name.to_string().as_str(), None)); @@ -539,7 +539,7 @@ pub fn storable_derive(input: TokenStream) -> TokenStream { }); let get_all_fn = quote! { - fn get_all(&self, parent: &P) -> Vec<(&str, Option)> { + fn get_all<'a>(&'a mut self, parent: &'a P) -> Vec<(&'a str, Option)> { let mut result = Vec::with_capacity(Self::names(parent).len()); #(#get_all_exprs)* result diff --git a/nuts-derive/tests/storable.rs b/nuts-derive/tests/storable.rs index fa0bf34..1cadbf6 100644 --- a/nuts-derive/tests/storable.rs +++ b/nuts-derive/tests/storable.rs @@ -61,7 +61,7 @@ fn test_storable() { value2: 8.0, draws2: vec![9.0, 2.0, 3.0], }; - let stats = ExampleStats { + let mut stats = ExampleStats { step_size: 0.1, n_steps: 10, is_adapting: true, @@ -72,7 +72,7 @@ fn test_storable() { _not_stored: "should not be stored".to_string(), }; - let stats2: Example2 = Example2 { + let mut stats2: Example2 = Example2 { field1: 42, field2: stats.clone(), _phantom: std::marker::PhantomData, diff --git a/nuts-storable/src/lib.rs b/nuts-storable/src/lib.rs index 401f45e..032cb0c 100644 --- a/nuts-storable/src/lib.rs +++ b/nuts-storable/src/lib.rs @@ -89,107 +89,7 @@ pub trait Storable: Send + Sync { fn item_type(parent: &P, item: &str) -> ItemType; fn dims<'a>(parent: &'a P, item: &str) -> Vec<&'a str>; - fn get_all(&self, parent: &P) -> Vec<(&str, Option)>; - - fn get_f64(&self, parent: &P, name: &str) -> Option { - self.get_all(parent) - .into_iter() - .find(|(item_name, _)| *item_name == name) - .and_then(|(_, value)| match value { - Some(Value::ScalarF64(v)) => Some(v), - _ => None, - }) - } - - fn get_f32(&self, parent: &P, name: &str) -> Option { - self.get_all(parent) - .into_iter() - .find(|(item_name, _)| *item_name == name) - .and_then(|(_, value)| match value { - Some(Value::ScalarF32(v)) => Some(v), - _ => None, - }) - } - - fn get_u64(&self, parent: &P, name: &str) -> Option { - self.get_all(parent) - .into_iter() - .find(|(item_name, _)| *item_name == name) - .and_then(|(_, value)| match value { - Some(Value::ScalarU64(v)) => Some(v), - _ => None, - }) - } - - fn get_i64(&self, parent: &P, name: &str) -> Option { - self.get_all(parent) - .into_iter() - .find(|(item_name, _)| *item_name == name) - .and_then(|(_, value)| match value { - Some(Value::ScalarI64(v)) => Some(v), - _ => None, - }) - } - - fn get_bool(&self, parent: &P, name: &str) -> Option { - self.get_all(parent) - .into_iter() - .find(|(item_name, _)| *item_name == name) - .and_then(|(_, value)| match value { - Some(Value::ScalarBool(v)) => Some(v), - _ => None, - }) - } - - fn get_vec_f64(&self, parent: &P, name: &str) -> Option> { - self.get_all(parent) - .into_iter() - .find(|(item_name, _)| *item_name == name) - .and_then(|(_, value)| match value { - Some(Value::F64(v)) => Some(v), - _ => None, - }) - } - - fn get_vec_f32(&self, parent: &P, name: &str) -> Option> { - self.get_all(parent) - .into_iter() - .find(|(item_name, _)| *item_name == name) - .and_then(|(_, value)| match value { - Some(Value::F32(v)) => Some(v), - _ => None, - }) - } - - fn get_vec_u64(&self, parent: &P, name: &str) -> Option> { - self.get_all(parent) - .into_iter() - .find(|(item_name, _)| *item_name == name) - .and_then(|(_, value)| match value { - Some(Value::U64(v)) => Some(v), - _ => None, - }) - } - - fn get_vec_i64(&self, parent: &P, name: &str) -> Option> { - self.get_all(parent) - .into_iter() - .find(|(item_name, _)| *item_name == name) - .and_then(|(_, value)| match value { - Some(Value::I64(v)) => Some(v), - _ => None, - }) - } - - fn get_vec_bool(&self, parent: &P, name: &str) -> Option> { - self.get_all(parent) - .into_iter() - .find(|(item_name, _)| *item_name == name) - .and_then(|(_, value)| match value { - Some(Value::Bool(v)) => Some(v), - _ => None, - }) - } + fn get_all<'a>(&'a mut self, parent: &'a P) -> Vec<(&'a str, Option)>; } impl Storable

for Vec { @@ -205,7 +105,7 @@ impl Storable

for Vec { vec!["dim"] } - fn get_all(&self, _parent: &P) -> Vec<(&str, Option)> { + fn get_all<'a>(&'a mut self, _parent: &'a P) -> Vec<(&'a str, Option)> { vec![("value", Some(Value::F64(self.clone())))] } } @@ -223,7 +123,7 @@ impl Storable

for () { panic!("No items in unit type") } - fn get_all(&self, _parent: &P) -> Vec<(&str, Option)> { + fn get_all(&mut self, _parent: &P) -> Vec<(&str, Option)> { vec![] } } diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index 79939e4..f687f23 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -5,9 +5,9 @@ use nuts_storable::{HasDims, Storable}; use rand::Rng; use serde::Serialize; -use super::mass_matrix::MassMatrixAdaptStrategy; use super::stepsize::AcceptanceRateCollector; use super::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy}; +use crate::mass_matrix::MassMatrixAdaptStrategy; use crate::{ NutsError, chain::AdaptStrategy, @@ -210,6 +210,7 @@ pub struct GlobalStrategyStats, M: Storable

> { pub step_size: S, #[storable(flatten)] pub mass_matrix: M, + pub tuning: bool, #[storable(ignore)] _phantom: std::marker::PhantomData P>, } @@ -243,6 +244,7 @@ where self.step_size.extract_stats(math, ()) }, mass_matrix: self.mass_matrix.extract_stats(math, opt.mass_matrix), + tuning: self.tuning, _phantom: PhantomData, } } diff --git a/src/chain.rs b/src/chain.rs index 026312a..222cc54 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -217,7 +217,7 @@ pub struct NutsStats, A: Storable

, D: Storable

> pub draw: u64, pub energy_error: f64, #[storable(dims("unconstrained_parameter"))] - pub unconstrained: Option>, + pub unconstrained_draw: Option>, #[storable(dims("unconstrained_parameter"))] pub gradient: Option>, #[storable(flatten)] @@ -289,7 +289,7 @@ impl> SamplerStats for NutsChain CpuMath { pub enum CpuMathError { #[error("Error during array operation")] ArrayError(), - #[error("Error during point expansion")] - ExpandError(), + #[error("Error during point expansion: {0}")] + ExpandError(String), } impl HasDims for CpuMath { @@ -57,7 +57,10 @@ impl Storable> for ExpandedVectorWrapper { F::ExpandedVector::dims(&parent.logp_func, item) } - fn get_all(&self, parent: &CpuMath) -> Vec<(&str, Option)> { + fn get_all<'a>( + &'a mut self, + parent: &'a CpuMath, + ) -> Vec<(&'a str, Option)> { self.0.get_all(&parent.logp_func) } } @@ -138,7 +141,9 @@ impl Math for CpuMath { rng, array .try_as_col_major() - .ok_or(CpuMathError::ExpandError())? + .ok_or_else(|| { + CpuMathError::ExpandError("Internal vector was not col major".into()) + })? .as_slice(), )?, )) diff --git a/src/lib.rs b/src/lib.rs index b722803..0a2a61b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -145,3 +145,6 @@ pub use storage::{CsvConfig, CsvTraceStorage}; pub use storage::{HashMapConfig, HashMapValue}; #[cfg(feature = "ndarray")] pub use storage::{NdarrayConfig, NdarrayTrace, NdarrayValue}; + +#[cfg(feature = "arrow")] +pub use storage::{ArrowConfig, ArrowTrace, ArrowTraceStorage}; diff --git a/src/mass_matrix/adapt.rs b/src/mass_matrix/adapt.rs index 28166ba..bb0d03b 100644 --- a/src/mass_matrix/adapt.rs +++ b/src/mass_matrix/adapt.rs @@ -4,7 +4,7 @@ use nuts_derive::Storable; use rand::Rng; use serde::Serialize; -use super::mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance}; +use super::diagonal::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance}; use crate::{ Math, NutsError, euclidean_hamiltonian::EuclideanPoint, diff --git a/src/mass_matrix/mass_matrix.rs b/src/mass_matrix/diagonal.rs similarity index 100% rename from src/mass_matrix/mass_matrix.rs rename to src/mass_matrix/diagonal.rs diff --git a/src/mass_matrix/low_rank.rs b/src/mass_matrix/low_rank.rs index 99f10cc..61ad9fc 100644 --- a/src/mass_matrix/low_rank.rs +++ b/src/mass_matrix/low_rank.rs @@ -1,4 +1,5 @@ use std::collections::VecDeque; +use std::iter::repeat; use faer::{Col, ColRef, Mat, MatRef, Scale}; use itertools::Itertools; @@ -6,7 +7,7 @@ use nuts_derive::Storable; use serde::Serialize; use super::adapt::MassMatrixAdaptStrategy; -use super::mass_matrix::{DrawGradCollector, MassMatrix}; +use super::diagonal::{DrawGradCollector, MassMatrix}; use crate::{ Math, NutsError, euclidean_hamiltonian::EuclideanPoint, hamiltonian::Point, sampler_stats::SamplerStats, @@ -127,15 +128,17 @@ impl Default for LowRankSettings { Self { store_mass_matrix: false, gamma: 1e-5, - eigval_cutoff: 100f64, + eigval_cutoff: 2f64, } } } #[derive(Debug, Storable)] pub struct MatrixStats { - pub eigvals: Option>, - pub stds: Option>, + #[storable(dims("unconstrained_parameter"))] + pub mass_matrix_eigvals: Option>, + #[storable(dims("unconstrained_parameter"))] + pub mass_matrix_stds: Option>, pub num_eigenvalues: u64, } @@ -145,14 +148,18 @@ impl SamplerStats for LowRankMassMatrix { fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { if self.settings.store_mass_matrix { + let stds = Some(math.box_array(&self.stds)); let eigvals = self .inner .as_ref() .map(|inner| math.eigs_as_array(&inner.vals)); - let stds = Some(math.box_array(&self.stds)); + let mut eigvals = eigvals.map(|x| x.into_vec()); + if let Some(ref mut eigvals) = eigvals { + eigvals.extend(repeat(f64::NAN).take(stds.as_ref().unwrap().len() - eigvals.len())); + } MatrixStats { - eigvals: eigvals.map(|x| x.into_vec()), - stds: stds.map(|x| x.into_vec()), + mass_matrix_eigvals: eigvals, + mass_matrix_stds: stds.map(|x| x.into_vec()), num_eigenvalues: self .inner .as_ref() @@ -161,9 +168,13 @@ impl SamplerStats for LowRankMassMatrix { } } else { MatrixStats { - eigvals: None, - stds: None, - num_eigenvalues: 0, + mass_matrix_eigvals: None, + mass_matrix_stds: None, + num_eigenvalues: self + .inner + .as_ref() + .map(|inner| inner.num_eigenvalues) + .unwrap_or(0), } } } @@ -328,7 +339,7 @@ impl LowRankMassMatrixStrategy { fn rescale_points(draws: &mut Mat, grads: &mut Mat) -> Col { let (ndim, ndraws) = draws.shape(); - + Col::from_fn(ndim, |col| { let draw_mean = draws.row(col).sum() / (ndraws as f64); let grad_mean = grads.row(col).sum() / (ndraws as f64); diff --git a/src/mass_matrix/mod.rs b/src/mass_matrix/mod.rs index 8409350..015e134 100644 --- a/src/mass_matrix/mod.rs +++ b/src/mass_matrix/mod.rs @@ -1,10 +1,10 @@ mod adapt; +mod diagonal; mod low_rank; -mod mass_matrix; pub use adapt::DiagAdaptExpSettings; pub(crate) use adapt::MassMatrixAdaptStrategy; pub(crate) use adapt::Strategy; +pub(crate) use diagonal::{DiagMassMatrix, MassMatrix}; pub use low_rank::LowRankSettings; pub(crate) use low_rank::{LowRankMassMatrix, LowRankMassMatrixStrategy}; -pub(crate) use mass_matrix::{DiagMassMatrix, MassMatrix}; diff --git a/src/sampler.rs b/src/sampler.rs index 0ccb9ca..ebbaa6f 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -626,7 +626,7 @@ impl ChainProcess { let now = Instant::now(); //let (point, info) = sampler.draw().unwrap(); - let (_point, draw_data, stats, info) = sampler.expanded_draw().unwrap(); + let (_point, mut draw_data, mut stats, info) = sampler.expanded_draw().unwrap(); let mut guard = chain_trace .lock() @@ -680,7 +680,8 @@ impl ChainProcess { .lock() .map_err(|_| anyhow::anyhow!("Could not lock trace mutex")) .context("Could not flush trace")? - .as_mut().map(|v| v.flush()) + .as_mut() + .map(|v| v.flush()) .transpose()?; Ok(()) } @@ -692,11 +693,13 @@ enum SamplerCommand { Continue, Progress, Flush, + Inspect, } -enum SamplerResponse { +enum SamplerResponse { Ok(), Progress(Box<[ChainProgress]>), + Inspect(T), } pub enum SamplerWaitResult { @@ -708,7 +711,7 @@ pub enum SamplerWaitResult { pub struct Sampler { main_thread: JoinHandle, F)>>, commands: SyncSender, - responses: Receiver, + responses: Receiver, F)>>, results: Receiver>, } @@ -826,7 +829,11 @@ impl Sampler { pause_start = Instant::now(); } is_paused = true; - responses_tx.send(SamplerResponse::Ok())?; + responses_tx.send(SamplerResponse::Ok()).map_err(|e| { + anyhow::anyhow!( + "Could not send pause response to controller thread: {e}" + ) + })?; } Ok(SamplerCommand::Continue) => { for chain in chains.iter() { @@ -836,18 +843,50 @@ impl Sampler { } pause_time += pause_start.elapsed(); is_paused = false; - responses_tx.send(SamplerResponse::Ok())?; + responses_tx.send(SamplerResponse::Ok()).map_err(|e| { + anyhow::anyhow!( + "Could not send continue response to controller thread: {e}" + ) + })?; } Ok(SamplerCommand::Progress) => { let progress = chains.iter().map(|chain| chain.progress()).collect_vec(); - responses_tx.send(SamplerResponse::Progress(progress.into()))?; + responses_tx.send(SamplerResponse::Progress(progress.into())).map_err(|e| { + anyhow::anyhow!( + "Could not send progress response to controller thread: {e}" + ) + })?; + } + Ok(SamplerCommand::Inspect) => { + let traces = chains + .iter() + .map(|chain| { + chain + .trace + .lock() + .expect("Poisoned lock") + .as_ref() + .map(|v| v.inspect()) + }) + .flatten() + .collect_vec(); + let finalized_trace = trace.inspect(traces)?; + responses_tx.send(SamplerResponse::Inspect(finalized_trace)).map_err(|e| { + anyhow::anyhow!( + "Could not send inspect response to controller thread: {e}" + ) + })?; } Ok(SamplerCommand::Flush) => { for chain in chains.iter() { chain.flush()?; } - responses_tx.send(SamplerResponse::Ok())?; + responses_tx.send(SamplerResponse::Ok()).map_err(|e| { + anyhow::anyhow!( + "Could not send flush response to controller thread: {e}" + ) + })?; } Err(RecvTimeoutError::Timeout) => {} Err(RecvTimeoutError::Disconnected) => { @@ -918,6 +957,18 @@ impl Sampler { Ok(()) } + pub fn inspect(&mut self) -> Result<(Option, F)> { + self.commands.send(SamplerCommand::Inspect)?; + let response = self + .responses + .recv() + .context("Could not recieve inspect response from controller thread")?; + let SamplerResponse::Inspect(trace) = response else { + bail!("Got invalid response from sample controller thread"); + }; + Ok(trace) + } + pub fn abort(self) -> Result<(Option, F)> { drop(self.commands); let result = self.main_thread.join(); diff --git a/src/state.rs b/src/state.rs index c680924..fac1893 100644 --- a/src/state.rs +++ b/src/state.rs @@ -104,10 +104,12 @@ impl> State { impl> Drop for State { fn drop(&mut self) { let rc = unsafe { std::mem::ManuallyDrop::take(&mut self.inner) }; - if (Rc::strong_count(&rc) == 1) & (Rc::weak_count(&rc) == 0) - && let Some(storage) = rc.reuser.upgrade() { - storage.free_states.borrow_mut().push(rc); - } + if (Rc::strong_count(&rc) == 1) + && (Rc::weak_count(&rc) == 0) + && let Some(storage) = rc.reuser.upgrade() + { + storage.free_states.borrow_mut().push(rc); + } } } diff --git a/src/storage/arrow.rs b/src/storage/arrow.rs new file mode 100644 index 0000000..c1aa571 --- /dev/null +++ b/src/storage/arrow.rs @@ -0,0 +1,640 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use arrow::array::{ + ArrayBuilder, ArrayRef, BooleanBuilder, Float32Builder, Float64Builder, Int64Builder, + LargeListBuilder, RecordBatch, RecordBatchOptions, StringBuilder, UInt64Builder, +}; +use arrow::datatypes::{DataType, Field, Schema}; +use nuts_storable::{ItemType, Value}; + +use crate::storage::{ChainStorage, StorageConfig, TraceStorage}; +use crate::{Math, Progress, Settings}; + +/// Container for different types of Arrow array builders +enum ArrowBuilder { + Tensor(LargeListBuilder>), + Scalar(Box), +} + +impl ArrowBuilder { + fn new(item_type: ItemType, capacity: usize, shape: Vec) -> Result { + let list_size = shape.iter().product::(); + let capacity = capacity + .checked_mul(list_size) + .ok_or_else(|| anyhow::anyhow!("Capacity overflow when creating ArrowBuilder"))?; + + let value_builder: Box = match item_type { + ItemType::F64 => Box::new(Float64Builder::with_capacity(capacity)), + ItemType::F32 => Box::new(Float32Builder::with_capacity(capacity)), + ItemType::Bool => Box::new(BooleanBuilder::with_capacity(capacity)), + ItemType::I64 => Box::new(Int64Builder::with_capacity(capacity)), + ItemType::U64 => Box::new(UInt64Builder::with_capacity(capacity)), + ItemType::String => Box::new(StringBuilder::with_capacity(capacity, capacity)), + }; + + if shape.is_empty() { + Ok(ArrowBuilder::Scalar(value_builder)) + } else { + let data_type = item_type_to_arrow_type(item_type); + let list_builder = LargeListBuilder::new(value_builder); + let list_builder = list_builder.with_field(Field::new("item", data_type, false)); + Ok(ArrowBuilder::Tensor(list_builder)) + } + } + + fn append_value(&mut self, value: Value) -> Result<()> { + macro_rules! downcast_builder { + ($builder:expr, $ty:ty, $variant:ident) => { + $builder + .as_any_mut() + .downcast_mut::<$ty>() + .ok_or_else(|| anyhow::anyhow!(concat!("Expected ", stringify!($ty)))) + }; + } + match self { + ArrowBuilder::Scalar(builder) => match value { + Value::ScalarF64(v) => { + downcast_builder!(builder, Float64Builder, ScalarF64)?.append_value(v); + } + Value::ScalarF32(v) => { + downcast_builder!(builder, Float32Builder, ScalarF32)?.append_value(v); + } + Value::ScalarBool(v) => { + downcast_builder!(builder, BooleanBuilder, ScalarBool)?.append_value(v); + } + Value::ScalarI64(v) => { + downcast_builder!(builder, Int64Builder, ScalarI64)?.append_value(v); + } + Value::ScalarU64(v) => { + downcast_builder!(builder, UInt64Builder, ScalarU64)?.append_value(v); + } + Value::ScalarString(v) => { + downcast_builder!(builder, StringBuilder, ScalarString)?.append_value(&v); + } + Value::U64(items) => { + assert!(items.len() == 1); + downcast_builder!(builder, UInt64Builder, U64)?.append_slice(items.as_slice()); + } + Value::I64(items) => { + assert!(items.len() == 1); + downcast_builder!(builder, Int64Builder, I64)?.append_slice(items.as_slice()); + } + Value::F64(items) => { + assert!(items.len() == 1); + downcast_builder!(builder, Float64Builder, F64)?.append_slice(items.as_slice()); + } + Value::F32(items) => { + assert!(items.len() == 1); + downcast_builder!(builder, Float32Builder, F32)?.append_slice(items.as_slice()); + } + Value::Bool(items) => { + assert!(items.len() == 1); + downcast_builder!(builder, BooleanBuilder, Bool)? + .append_slice(items.as_slice()); + } + Value::Strings(items) => { + let string_builder = downcast_builder!(builder, StringBuilder, Strings)?; + for item in items { + string_builder.append_value(&item); + } + } + }, + ArrowBuilder::Tensor(list_builder) => { + match value { + Value::F64(v) => { + downcast_builder!(list_builder.values(), Float64Builder, F64)? + .append_slice(v.as_slice()); + } + Value::F32(v) => { + downcast_builder!(list_builder.values(), Float32Builder, F32)? + .append_slice(v.as_slice()); + } + Value::I64(v) => { + downcast_builder!(list_builder.values(), Int64Builder, I64)? + .append_slice(v.as_slice()); + } + Value::U64(v) => { + downcast_builder!(list_builder.values(), UInt64Builder, U64)? + .append_slice(v.as_slice()); + } + Value::Bool(v) => { + downcast_builder!(list_builder.values(), BooleanBuilder, Bool)? + .append_slice(v.as_slice()); + } + Value::Strings(items) => { + let string_builder = + downcast_builder!(list_builder.values(), StringBuilder, Strings)?; + for item in items { + string_builder.append_value(&item); + } + } + Value::ScalarString(val) => { + downcast_builder!(list_builder.values(), StringBuilder, ScalarString)? + .append_value(val); + } + Value::ScalarU64(val) => { + downcast_builder!(list_builder.values(), UInt64Builder, ScalarU64)? + .append_value(val); + } + Value::ScalarI64(val) => { + downcast_builder!(list_builder.values(), Int64Builder, ScalarI64)? + .append_value(val); + } + Value::ScalarF64(val) => { + downcast_builder!(list_builder.values(), Float64Builder, ScalarF64)? + .append_value(val); + } + Value::ScalarF32(val) => { + downcast_builder!(list_builder.values(), Float32Builder, ScalarF32)? + .append_value(val); + } + Value::ScalarBool(val) => { + downcast_builder!(list_builder.values(), BooleanBuilder, ScalarBool)? + .append_value(val); + } + } + list_builder.append(true); + } + } + Ok(()) + } + + fn append_null(&mut self) -> Result<()> { + match self { + ArrowBuilder::Scalar(builder) => { + if let Some(builder) = builder.as_any_mut().downcast_mut::() { + builder.append_null(); + } else if let Some(builder) = builder.as_any_mut().downcast_mut::() + { + builder.append_null(); + } else if let Some(builder) = builder.as_any_mut().downcast_mut::() { + builder.append_null(); + } else if let Some(builder) = builder.as_any_mut().downcast_mut::() { + builder.append_null(); + } else if let Some(builder) = builder.as_any_mut().downcast_mut::() + { + builder.append_null(); + } else if let Some(builder) = builder.as_any_mut().downcast_mut::() { + builder.append_null(); + } else { + return Err(anyhow::anyhow!("Unknown builder type for null")); + } + } + ArrowBuilder::Tensor(builder) => builder.append(false), + } + Ok(()) + } + + fn finish(&mut self) -> ArrayRef { + match self { + ArrowBuilder::Scalar(builder) => Arc::new(builder.finish()), + ArrowBuilder::Tensor(builder) => Arc::new(builder.finish()), + } + } + + fn finish_cloned(&self) -> ArrayRef { + match self { + ArrowBuilder::Scalar(builder) => Arc::new(builder.finish_cloned()), + ArrowBuilder::Tensor(builder) => Arc::new(builder.finish_cloned()), + } + } +} + +/// Convert ItemType to Arrow DataType +fn item_type_to_arrow_type(item_type: ItemType) -> DataType { + match item_type { + ItemType::F64 => DataType::Float64, + ItemType::F32 => DataType::Float32, + ItemType::U64 => DataType::UInt64, + ItemType::I64 => DataType::Int64, + ItemType::Bool => DataType::Boolean, + ItemType::String => DataType::Utf8, + } +} + +/// Create a field with tensor extension type if shape is provided +fn create_field_with_shape( + name: &str, + item_type: ItemType, + dims: &Vec, + dim_sizes: &HashMap, +) -> Result { + let arrow_type = item_type_to_arrow_type(item_type); + + if !dims.is_empty() { + // Multi-dimensional tensor + let metadata = HashMap::from([ + ( + "dims".to_string(), + dims.iter().cloned().collect::>().join(","), + ), + ( + "shape".to_string(), + dims.iter() + .map(|dim| { + dim_sizes + .get(dim) + .copied() + .map(|size| size.to_string()) + .expect("Dimension size not found") + }) + .collect::>() + .join(","), + ), + ]); + + let inner_field = Field::new("item", arrow_type, false); + let field = Field::new_large_list(name, inner_field, true); + let field = field.with_metadata(metadata); + Ok(field) + } else { + Ok(Field::new(name, arrow_type, true)) + } +} +/// Main storage for Arrow MCMC traces +pub struct ArrowTraceStorage { + stat_types: Vec<(String, ItemType)>, + draw_types: Vec<(String, ItemType)>, + stat_dims: Vec<(String, Vec)>, + draw_dims: Vec<(String, Vec)>, + stat_dim_sizes: HashMap, + draw_dim_sizes: HashMap, + expected_draws: usize, +} + +/// Per-chain storage for Arrow MCMC traces +pub struct ArrowChainStorage { + draw_builders: Vec<(String, ArrowBuilder)>, + stats_builders: Vec<(String, ArrowBuilder)>, + stat_types: Vec<(String, ItemType)>, + draw_types: Vec<(String, ItemType)>, + stats_dims: Vec<(String, Vec)>, + draw_dims: Vec<(String, Vec)>, + stat_dim_sizes: HashMap, + draw_dim_sizes: HashMap, + draw_count: usize, +} + +/// Final result containing Arrow record batches +#[derive(Clone, Debug)] +pub struct ArrowTrace { + pub posterior: RecordBatch, + pub sample_stats: RecordBatch, +} + +impl ArrowChainStorage { + fn new( + expected_draws: usize, + stat_types: &[(String, ItemType)], + draw_types: &[(String, ItemType)], + stat_dims: &[(String, Vec)], + draw_dims: &[(String, Vec)], + stat_dim_sizes: &HashMap, + draw_dim_sizes: &HashMap, + ) -> Result { + let draw_builders = draw_types + .iter() + .zip(draw_dims.iter()) + .map(|((name, item_type), (name2, dims))| { + assert_eq!( + name, name2, + "Draw types and dims must have matching names and order" + ); + let shape = dims + .iter() + .map(|dim| { + draw_dim_sizes + .get(dim) + .copied() + .map(|x| x as usize) + .ok_or_else(|| { + anyhow::anyhow!("Unknown dimension size for dimension {}", dim) + }) + }) + .collect::>>()?; + Ok(( + name.clone(), + ArrowBuilder::new(*item_type, expected_draws, shape)?, + )) + }) + .collect::>>()?; + + let stats_builders = stat_types + .iter() + .zip(stat_dims.iter()) + .map(|((name, item_type), (name2, dims))| { + assert_eq!( + name, name2, + "Draw types and dims must have matching names and order" + ); + let shape = dims + .iter() + .map(|dim| { + stat_dim_sizes + .get(dim) + .copied() + .map(|x| x as usize) + .ok_or_else(|| { + anyhow::anyhow!("Unknown dimension size for dimension {}", dim) + }) + }) + .collect::>>()?; + Ok(( + name.clone(), + ArrowBuilder::new(*item_type, expected_draws, shape)?, + )) + }) + .collect::>>()?; + + Ok(Self { + draw_builders, + stats_builders, + stat_types: stat_types.to_vec(), + draw_types: draw_types.to_vec(), + stats_dims: stat_dims.to_vec(), + draw_dims: draw_dims.to_vec(), + stat_dim_sizes: stat_dim_sizes.clone(), + draw_dim_sizes: draw_dim_sizes.clone(), + draw_count: 0, + }) + } + + fn finalize_builders(mut self) -> Result { + // Create posterior schema and arrays + + let posterior_fields = self + .draw_types + .iter() + .zip(self.draw_dims.iter()) + .map(|((name, item_type), (_, dims))| { + create_field_with_shape(name, *item_type, dims, &self.draw_dim_sizes) + }) + .collect::>>()?; + + let posterior_arrays: Vec = self + .draw_builders + .iter_mut() + .map(|(_, builder)| builder.finish()) + .collect(); + + let posterior_schema = Schema::new(posterior_fields); + let posterior_options = RecordBatchOptions::new().with_row_count(Some(self.draw_count)); + let posterior = RecordBatch::try_new_with_options( + Arc::new(posterior_schema), + posterior_arrays, + &posterior_options, + ) + .context("Could not convert posterior to RecordBatch")?; + + // Create stats schema and arrays + let stats_fields = self + .stat_types + .iter() + .zip(self.stats_dims.iter()) + .map(|((name, item_type), (_, dims))| { + create_field_with_shape(name, *item_type, dims, &self.stat_dim_sizes) + }) + .collect::>>()?; + + let stats_arrays: Vec = self + .stats_builders + .iter_mut() + .map(|(_, builder)| builder.finish()) + .collect(); + + let stats_schema = Schema::new(stats_fields); + let stats_options = RecordBatchOptions::new().with_row_count(Some(self.draw_count)); + let sample_stats = + RecordBatch::try_new_with_options(Arc::new(stats_schema), stats_arrays, &stats_options) + .context("Could not convert sample stats to RecordBatch")?; + + Ok(ArrowTrace { + posterior, + sample_stats, + }) + } +} + +impl ChainStorage for ArrowChainStorage { + type Finalized = ArrowTrace; + + fn record_sample( + &mut self, + _settings: &impl Settings, + stats: Vec<(&str, Option)>, + draws: Vec<(&str, Option)>, + _info: &Progress, + ) -> Result<()> { + stats + .into_iter() + .zip(self.stats_builders.iter_mut()) + .try_for_each(|((name, value), (expected_name, builder))| { + if name != expected_name { + panic!( + "Draw name mismatch: expected {}, got {}", + expected_name, name + ); + } + + if let Some(value) = value { + builder.append_value(value)?; + } else { + builder.append_null()?; + } + Ok::<_, anyhow::Error>(()) + })?; + + draws + .into_iter() + .zip(self.draw_builders.iter_mut()) + .try_for_each(|((name, value), (expected_name, builder))| { + if name != expected_name { + panic!( + "Draw name mismatch: expected {}, got {}", + expected_name, name + ); + } + + if let Some(value) = value { + builder.append_value(value)?; + } else { + builder.append_null()?; + } + Ok::<_, anyhow::Error>(()) + })?; + + self.draw_count += 1; + + Ok(()) + } + + fn finalize(self) -> Result { + self.finalize_builders() + } + + fn flush(&self) -> Result<()> { + // No-op for in-memory storage + Ok(()) + } + + fn inspect(&self) -> Result> { + let posterior_fields = self + .draw_types + .iter() + .zip(self.draw_dims.iter()) + .map(|((name, item_type), (_, dims))| { + create_field_with_shape(name, *item_type, dims, &self.draw_dim_sizes) + }) + .collect::>>()?; + + let posterior_arrays: Vec = self + .draw_builders + .iter() + .map(|(_, builder)| builder.finish_cloned()) + .collect(); + + let posterior_schema = Schema::new(posterior_fields); + let posterior_options = RecordBatchOptions::new().with_row_count(Some(self.draw_count)); + let posterior = RecordBatch::try_new_with_options( + Arc::new(posterior_schema), + posterior_arrays, + &posterior_options, + ) + .context("Could not convert posterior to RecordBatch")?; + + // Create stats schema and arrays + let stats_fields = self + .stat_types + .iter() + .zip(self.stats_dims.iter()) + .map(|((name, item_type), (_, dims))| { + create_field_with_shape(name, *item_type, dims, &self.stat_dim_sizes) + }) + .collect::>>()?; + + let stats_arrays: Vec = self + .stats_builders + .iter() + .map(|(_, builder)| builder.finish_cloned()) + .collect(); + + let stats_schema = Schema::new(stats_fields); + let stats_options = RecordBatchOptions::new().with_row_count(Some(self.draw_count)); + let sample_stats = + RecordBatch::try_new_with_options(Arc::new(stats_schema), stats_arrays, &stats_options) + .context("Could not convert sample stats to RecordBatch")?; + + Ok(Some(ArrowTrace { + posterior, + sample_stats, + })) + } +} + +/// Configuration for Arrow-based MCMC storage. +/// +/// This storage backend keeps all data in memory using Arrow's columnar format. +/// It's efficient for moderate-sized datasets and provides interoperability +/// with other Arrow-based tools. Multi-dimensional parameters +/// are stored as Arrow LargeList arrays with custom metadata containing +/// dimension names. +pub struct ArrowConfig {} + +impl ArrowConfig { + /// Create a new Arrow configuration. + pub fn new() -> Self { + Self {} + } +} + +impl Default for ArrowConfig { + fn default() -> Self { + Self::new() + } +} + +impl StorageConfig for ArrowConfig { + type Storage = ArrowTraceStorage; + + fn new_trace(self, settings: &impl Settings, math: &M) -> Result { + let stat_types = settings.stat_types(math); + let draw_types = settings.data_types(math); + let stat_dims = settings.stat_dims_all(math); + let draw_dims = settings.data_dims_all(math); + let stat_dim_sizes = settings.stat_dim_sizes(math); + let draw_dim_sizes = math.dim_sizes(); + + // Calculate expected total draws (warmup + sampling) + let expected_draws = (settings.hint_num_tune() + settings.hint_num_draws()) as usize; + + Ok(ArrowTraceStorage { + stat_types, + draw_types, + stat_dims, + draw_dims, + stat_dim_sizes, + draw_dim_sizes, + expected_draws, + }) + } +} + +impl TraceStorage for ArrowTraceStorage { + type ChainStorage = ArrowChainStorage; + type Finalized = Vec; + + fn initialize_trace_for_chain(&self, _chain_id: u64) -> Result { + ArrowChainStorage::new( + self.expected_draws, + &self.stat_types, + &self.draw_types, + &self.stat_dims, + &self.draw_dims, + &self.stat_dim_sizes, + &self.draw_dim_sizes, + ) + } + + fn finalize( + self, + traces: Vec::Finalized>>, + ) -> Result<(Option, Self::Finalized)> { + let mut results = Vec::new(); + let mut first_error = None; + + for trace in traces { + match trace { + Ok(trace) => results.push(trace), + Err(err) => { + if first_error.is_none() { + first_error = Some(err); + } + } + } + } + Ok((first_error, results)) + } + + fn inspect( + &self, + traces: Vec::Finalized>>>, + ) -> Result<(Option, Self::Finalized)> { + let mut results = Vec::new(); + let mut first_error = None; + + for trace in traces { + match trace { + Ok(Some(trace)) => results.push(trace), + Ok(None) => {} + Err(err) => { + if first_error.is_none() { + first_error = Some(err); + } + } + } + } + Ok((first_error, results)) + } +} diff --git a/src/storage/storage.rs b/src/storage/core.rs similarity index 89% rename from src/storage/storage.rs rename to src/storage/core.rs index 5032c1e..8afaaf0 100644 --- a/src/storage/storage.rs +++ b/src/storage/core.rs @@ -23,6 +23,10 @@ pub trait ChainStorage: Send { /// Finalizes the storage and returns processed results. fn finalize(self) -> Result; + fn inspect(&self) -> Result> { + Ok(None) + } + /// Flush any buffered data to ensure all samples are stored. fn flush(&self) -> Result<()>; } @@ -63,4 +67,9 @@ pub trait TraceStorage: Send + Sync + Sized + 'static { self, traces: Vec::Finalized>>, ) -> Result<(Option, Self::Finalized)>; + + fn inspect( + &self, + traces: Vec::Finalized>>>, + ) -> Result<(Option, Self::Finalized)>; } diff --git a/src/storage/csv.rs b/src/storage/csv.rs index 3773223..005e525 100644 --- a/src/storage/csv.rs +++ b/src/storage/csv.rs @@ -348,6 +348,12 @@ impl ChainStorage for CsvChainStorage { // In practice, the buffer will be flushed when the file is closed Ok(()) } + + fn inspect(&self) -> Result> { + // For CSV storage, inspection does not produce a finalized result + self.flush()?; + Ok(None) + } } impl StorageConfig for CsvConfig { @@ -599,6 +605,19 @@ impl TraceStorage for CsvTraceStorage { } Ok((None, ())) } + + fn inspect( + &self, + traces: Vec::Finalized>>>, + ) -> Result<(Option, Self::Finalized)> { + // Check for any errors in the chain inspections + for trace_result in traces { + if let Err(err) = trace_result { + return Ok((Some(err), ())); + } + } + Ok((None, ())) + } } #[cfg(test)] diff --git a/src/storage/hashmap.rs b/src/storage/hashmap.rs index a5c6e2b..204b8d7 100644 --- a/src/storage/hashmap.rs +++ b/src/storage/hashmap.rs @@ -51,12 +51,14 @@ impl HashMapValue { } /// Main storage for HashMap MCMC traces +#[derive(Clone)] pub struct HashMapTraceStorage { draw_types: Vec<(String, ItemType)>, param_types: Vec<(String, ItemType)>, } /// Per-chain storage for HashMap MCMC traces +#[derive(Clone)] pub struct HashMapChainStorage { warmup_stats: HashMap, sample_stats: HashMap, @@ -251,6 +253,10 @@ impl ChainStorage for HashMapChainStorage { fn flush(&self) -> Result<()> { Ok(()) } + + fn inspect(&self) -> Result> { + self.clone().finalize().map(Some) + } } pub struct HashMapConfig {} @@ -314,4 +320,12 @@ impl TraceStorage for HashMapTraceStorage { Ok((first_error, results)) } + + fn inspect( + &self, + traces: Vec::Finalized>>>, + ) -> Result<(Option, Self::Finalized)> { + self.clone() + .finalize(traces.into_iter().map(|r| r.map(|o| o.unwrap())).collect()) + } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index d8c04c3..fdc3141 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,11 +1,15 @@ +#[cfg(feature = "arrow")] +mod arrow; +mod core; mod csv; mod hashmap; #[cfg(feature = "ndarray")] mod ndarray; -mod storage; #[cfg(feature = "zarr")] mod zarr; +#[cfg(feature = "arrow")] +pub use arrow::{ArrowConfig, ArrowTrace, ArrowTraceStorage}; #[cfg(feature = "zarr")] pub use zarr::{ZarrAsyncConfig, ZarrAsyncTraceStorage, ZarrConfig, ZarrTraceStorage}; @@ -14,4 +18,4 @@ pub use hashmap::{HashMapConfig, HashMapValue}; #[cfg(feature = "ndarray")] pub use ndarray::{NdarrayConfig, NdarrayTrace, NdarrayValue}; -pub use storage::{ChainStorage, StorageConfig, TraceStorage}; +pub use core::{ChainStorage, StorageConfig, TraceStorage}; diff --git a/src/storage/ndarray.rs b/src/storage/ndarray.rs index c02dfff..39e2d93 100644 --- a/src/storage/ndarray.rs +++ b/src/storage/ndarray.rs @@ -136,6 +136,7 @@ struct SharedArrays { } /// Main storage for ndarray MCMC traces +#[derive(Clone)] pub struct NdarrayTraceStorage { shared_arrays: Arc>, } @@ -192,6 +193,12 @@ impl NdarrayChainStorage { pub struct NdarrayConfig {} +impl Default for NdarrayConfig { + fn default() -> Self { + Self::new() + } +} + impl NdarrayConfig { pub fn new() -> Self { Self {} @@ -322,10 +329,10 @@ impl TraceStorage for NdarrayTraceStorage { let mut first_error = None; for trace in traces { - if let Err(err) = trace { - if first_error.is_none() { - first_error = Some(err); - } + if let Err(err) = trace + && first_error.is_none() + { + first_error = Some(err); } } @@ -342,4 +349,20 @@ impl TraceStorage for NdarrayTraceStorage { Ok((first_error, result)) } + + fn inspect( + &self, + traces: Vec::Finalized>>>, + ) -> Result<(Option, Self::Finalized)> { + self.clone().finalize( + traces + .into_iter() + .map(|res| match res { + Ok(Some(_)) => Ok(()), + Ok(None) => Ok(()), + Err(err) => Err(err), + }) + .collect(), + ) + } } diff --git a/src/storage/zarr/async_impl.rs b/src/storage/zarr/async_impl.rs index acde6de..a361aec 100644 --- a/src/storage/zarr/async_impl.rs +++ b/src/storage/zarr/async_impl.rs @@ -123,13 +123,7 @@ fn store_zarr_chunk_sync( chain_chunk_index: u64, ) -> Result<()> { let array = array.clone(); - handle.block_on(async move { - tokio::runtime::Handle::current().block_on(store_zarr_chunk_async( - array, - data, - chain_chunk_index, - )) - }) + handle.block_on(async move { store_zarr_chunk_async(array, data, chain_chunk_index).await }) } /// Store coordinates in zarr arrays @@ -149,14 +143,10 @@ async fn store_coords( _ => panic!("Unsupported coordinate type for {}", name), }; let name: &String = name; - let coord_array = ArrayBuilder::new( - vec![len as u64], - data_type, - vec![len as u64].try_into().expect("Invalid chunk size"), - fill_value, - ) - .dimension_names(Some(vec![name.to_string()])) - .build(store.clone(), &format!("{}/{}", group, name))?; + let coord_array = + ArrayBuilder::new(vec![len as u64], vec![len as u64], data_type, fill_value) + .dimension_names(Some(vec![name.to_string()])) + .build(store.clone(), &format!("{}/{}", group, name))?; let subset = vec![0]; match coord { &Value::F64(ref v) => { @@ -654,4 +644,16 @@ impl TraceStorage for ZarrAsyncTraceStorage { } Ok((None, ())) } + + fn inspect( + &self, + traces: Vec::Finalized>>>, + ) -> Result<(Option, Self::Finalized)> { + for trace in traces { + if let Err(err) = trace { + return Ok((Some(err), ())); + }; + } + Ok((None, ())) + } } diff --git a/src/storage/zarr/common.rs b/src/storage/zarr/common.rs index 734e5b0..5e5d222 100644 --- a/src/storage/zarr/common.rs +++ b/src/storage/zarr/common.rs @@ -134,11 +134,11 @@ impl SampleBuffer { (SampleBufferValue::U64(vec), Value::ScalarU64(v)) => vec.push(v), (SampleBufferValue::Bool(vec), Value::ScalarBool(v)) => vec.push(v), (SampleBufferValue::I64(vec), Value::ScalarI64(v)) => vec.push(v), - (SampleBufferValue::F64(vec), Value::F64(v)) => vec.extend(v.into_iter()), - (SampleBufferValue::F32(vec), Value::F32(v)) => vec.extend(v.into_iter()), - (SampleBufferValue::U64(vec), Value::U64(v)) => vec.extend(v.into_iter()), - (SampleBufferValue::Bool(vec), Value::Bool(v)) => vec.extend(v.into_iter()), - (SampleBufferValue::I64(vec), Value::I64(v)) => vec.extend(v.into_iter()), + (SampleBufferValue::F64(vec), Value::F64(v)) => vec.extend(v), + (SampleBufferValue::F32(vec), Value::F32(v)) => vec.extend(v), + (SampleBufferValue::U64(vec), Value::U64(v)) => vec.extend(v), + (SampleBufferValue::Bool(vec), Value::Bool(v)) => vec.extend(v), + (SampleBufferValue::I64(vec), Value::I64(v)) => vec.extend(v), _ => panic!("Mismatched item type"), } self.len += 1; @@ -181,12 +181,12 @@ pub fn create_arrays( anyhow::anyhow!("Unknown dimension size for dimension {}", dim) .context(format!("Could not write {}/{}", group_path, name)) }) - .map(|size| *size) + .copied() }) .collect(); let extra_shape = extra_shape?; - let shape: Vec = std::iter::once(n_chains as u64) - .chain(std::iter::once(n_draws as u64)) + let shape: Vec = std::iter::once(n_chains) + .chain(std::iter::once(n_draws)) .chain(extra_shape.clone()) .collect(); let zarr_type = match item_type { @@ -209,15 +209,9 @@ pub fn create_arrays( .chain(std::iter::once(draw_chunk_size)) .chain(extra_shape) .collect(); - let array = ArrayBuilder::new( - shape, - zarr_type, - grid.try_into().expect("Invalid chunk sizes"), - fill_value, - ) - .dimension_names(Some(dims)) - .build(store.clone(), &format!("{}/{}", group_path, name))?; - //array.store_metadata()?; + let array = ArrayBuilder::new(shape, grid, zarr_type, fill_value) + .dimension_names(Some(dims)) + .build(store.clone(), &format!("{}/{}", group_path, name))?; arrays.insert(name.to_string(), array); } Ok(arrays) diff --git a/src/storage/zarr/sync_impl.rs b/src/storage/zarr/sync_impl.rs index fb10d5c..0966c65 100644 --- a/src/storage/zarr/sync_impl.rs +++ b/src/storage/zarr/sync_impl.rs @@ -40,14 +40,11 @@ pub fn store_coords( _ => panic!("Unsupported coordinate type for {}", name), }; let name: &String = name; - let coord_array = ArrayBuilder::new( - vec![len as u64], - data_type, - vec![len as u64].try_into().expect("Invalid chunk size"), - fill_value, - ) - .dimension_names(Some(vec![name.to_string()])) - .build(store.clone(), &format!("{}/{}", group, name))?; + + let coord_array = + ArrayBuilder::new(vec![len as u64], vec![len as u64], data_type, fill_value) + .dimension_names(Some(vec![name.to_string()])) + .build(store.clone(), &format!("{}/{}", group, name))?; let subset = vec![0]; match coord { &Value::F64(ref v) => coord_array.store_chunk_elements::(&subset, v)?, @@ -534,4 +531,16 @@ impl TraceStorage for ZarrTraceStorage { } Ok((None, ())) } + + fn inspect( + &self, + traces: Vec::Finalized>>>, + ) -> Result<(Option, Self::Finalized)> { + for trace in traces { + if let Err(err) = trace { + return Ok((Some(err), ())); + }; + } + Ok((None, ())) + } } diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs index 6360ab1..6cffd93 100644 --- a/src/transform_adapt_strategy.rs +++ b/src/transform_adapt_strategy.rs @@ -43,14 +43,18 @@ pub struct TransformAdaptation { } #[derive(Debug, Storable)] -pub struct Stats {} +pub struct Stats { + tuning: bool, +} impl SamplerStats for TransformAdaptation { type Stats = Stats; type StatsOptions = (); fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { - Stats {} + Stats { + tuning: self.tuning, + } } } @@ -197,7 +201,7 @@ impl AdaptStrategy for TransformAdaptation { if draw < self.final_window_size { if draw < 100 { - if (draw > 0) & draw.is_multiple_of(10) { + if (draw > 0) && draw.is_multiple_of(10) { hamiltonian.update_params( math, rng, @@ -206,7 +210,7 @@ impl AdaptStrategy for TransformAdaptation { collector.collector2.logps.iter(), )?; } - } else if (draw > 0) & draw.is_multiple_of(self.options.transform_update_freq) { + } else if (draw > 0) && draw.is_multiple_of(self.options.transform_update_freq) { hamiltonian.update_params( math, rng, diff --git a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs index 7b97482..792e83a 100644 --- a/src/transformed_hamiltonian.rs +++ b/src/transformed_hamiltonian.rs @@ -26,7 +26,9 @@ pub struct TransformedPoint { #[derive(Debug, Storable)] pub struct PointStats { pub fisher_distance: f64, + #[storable(dims("unconstrained_parameter"))] pub transformed_position: Option>, + #[storable(dims("unconstrained_parameter"))] pub transformed_gradient: Option>, pub transformation_index: i64, }