diff --git a/src/progress.rs b/src/progress.rs index 2403e15..2b130a4 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -49,7 +49,7 @@ impl ProgressHandler { let progress = progress_to_value(progress_update_count, self.n_cores, time_sampling, progress); let rendered = template.render_from(&self.engine, &progress).to_string(); - let rendered = rendered.unwrap_or_else(|err| format!("{}", err)); + let rendered = rendered.unwrap_or_else(|err| format!("{err}")); let _ = Python::with_gil(|py| self.callback.call1(py, (rendered,))); progress_update_count += 1; }; diff --git a/src/pyfunc.rs b/src/pyfunc.rs index f23145e..de3ecd9 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -127,9 +127,8 @@ impl LogpError for PyLogpError { let Ok(attr) = err.value(py).getattr("is_recoverable") else { return false; }; - return attr - .is_truthy() - .expect("Could not access is_recoverable in error check"); + attr.is_truthy() + .expect("Could not access is_recoverable in error check") }), Self::ReturnTypeError() => false, Self::NotContiguousError(_) => false, @@ -151,7 +150,7 @@ impl PyDensity { transform_adapter: Option<&PyTransformAdapt>, ) -> Result { let logp_func = Python::with_gil(|py| logp_clone_func.call0(py))?; - let transform_adapter = transform_adapter.map(|val| val.clone()); + let transform_adapter = transform_adapter.cloned(); Ok(Self { logp: logp_func, transform_adapter, @@ -185,7 +184,7 @@ impl CpuLogpFunc for PyDensity { ); Ok(logp_val) } - Err(err) => return Err(PyLogpError::PyError(err)), + Err(err) => Err(PyLogpError::PyError(err)), } }) } @@ -359,7 +358,7 @@ impl TensorShape { Self { shape, dims, size } } pub fn size(&self) -> usize { - return self.size; + self.size } } @@ -617,14 +616,14 @@ impl Model for PyModel { settings: &'model S, ) -> Result> { let draws = settings.hint_num_tune() + settings.hint_num_draws(); - Ok(PyTrace::new( + PyTrace::new( rng, chain_id, self.variables.clone(), &self.make_expand_func, draws, ) - .context("Could not create PyTrace object")?) + .context("Could not create PyTrace object") } fn math(&self) -> Result> { diff --git a/src/pymc.rs b/src/pymc.rs index b33b821..685b768 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -112,7 +112,7 @@ impl LogpError for ErrorCode { } } -impl<'a> CpuLogpFunc for &'a LogpFunc { +impl CpuLogpFunc for &LogpFunc { type LogpError = ErrorCode; type TransformParams = (); @@ -175,7 +175,7 @@ impl<'model> DrawStorage for PyMcTrace<'model> { let num_arrays = data.len() / size; let data = Float64Array::from(data); let item_field = Arc::new(Field::new("item", DataType::Float64, false)); - let offsets = OffsetBuffer::from_lengths((0..num_arrays).into_iter().map(|_| size)); + let offsets = OffsetBuffer::from_lengths((0..num_arrays).map(|_| size)); let array = LargeListArray::new(item_field.clone(), offsets, Arc::new(data), None); let field = Field::new(name, DataType::LargeList(item_field), false); (Arc::new(field), Arc::new(array) as Arc) diff --git a/src/stan.rs b/src/stan.rs index e258096..0586e6f 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use std::{ffi::CString, path::PathBuf}; -use anyhow::Context; +use anyhow::{bail, Context}; use arrow::array::{Array, FixedSizeListArray, Float64Array, StructArray}; use arrow::datatypes::{DataType, Field}; use bridgestan::open_library; @@ -26,7 +26,7 @@ type InnerModel = bridgestan::Model>; #[derive(Clone)] pub struct StanLibrary(Arc); -#[derive(Clone)] +#[derive(Clone, Debug)] struct Parameter { name: String, shape: Vec, @@ -40,7 +40,7 @@ impl StanLibrary { #[new] fn new(path: PathBuf) -> PyResult { let lib = open_library(path) - .map_err(|e| PyValueError::new_err(format!("Could not open stan libray: {}", e)))?; + .map_err(|e| PyValueError::new_err(format!("Could not open stan libray: {e}")))?; Ok(Self(Arc::new(lib))) } } @@ -64,6 +64,16 @@ impl StanVariable { fn size(&self) -> usize { self.0.size } + + #[getter] + fn start_idx(&self) -> usize { + self.0.start_idx + } + + #[getter] + fn end_idx(&self) -> usize { + self.0.end_idx + } } #[pyclass] @@ -75,68 +85,155 @@ pub struct StanModel { } /// Return meta information about the constrained parameters of the model -fn params( - model: &InnerModel, - include_tp: bool, - include_gq: bool, -) -> anyhow::Result> { - let var_string = model.param_names(include_tp, include_gq); - let name_idxs: anyhow::Result)>> = var_string +fn params(var_string: &str) -> anyhow::Result> { + // Parse each variable string into (name, is_complex, indices) + let parsed_variables: anyhow::Result)>> = var_string .split(',') .map(|var| { - let mut parts = var.split('.'); - let name = parts - .next() - .ok_or_else(|| anyhow::Error::msg("Invalid parameter name"))?; - let idxs: anyhow::Result> = parts - .map(|mut idx| { - if idx == "real" { - idx = "1"; - } - if idx == "imag" { - idx = "2"; - } - let idx: usize = idx - .parse() - .map_err(|_| anyhow::Error::msg("Invalid parameter name"))?; - Ok(idx - 1) - }) - .collect(); - Ok((name, idxs?)) + let mut indices = vec![]; + let mut remaining = var; + let mut complex_suffix = None; + + // Parse from right to left, extracting indices and checking for complex type + while let Some(idx) = remaining.rfind('.') { + let suffix = &remaining[(idx + 1)..]; + + // Handle complex number suffixes + if suffix == "real" || suffix == "imag" { + complex_suffix = Some(suffix); + remaining = &remaining[..idx]; + continue; + } + + // Try to parse as index + if let Ok(index) = suffix.parse::() { + // Convert from 1-based to 0-based indexing + let zero_based_idx = index.checked_sub(1).ok_or_else(|| { + anyhow::Error::msg("Invalid parameter index (must be > 0)") + })?; + + indices.push(zero_based_idx); + remaining = &remaining[..idx]; + } else { + // Not a number - this is part of the variable name + break; + } + } + + // Variable name is what remains + let name = remaining.trim().to_string(); + + // Reverse indices since we parsed right-to-left + indices.reverse(); + + Ok((name, complex_suffix.is_some(), indices)) }) .collect(); + // Group variables by name and build Parameter objects let mut variables = Vec::new(); let mut start_idx = 0; - for (name, idxs) in &name_idxs?.iter().chunk_by(|(name, _)| name) { - let mut shape: Vec = idxs - .map(|(_name, idx)| idx) - .fold(None, |acc, elem| { - let mut shape = acc.unwrap_or(elem.clone()); - shape - .iter_mut() - .zip_eq(elem.iter()) - .for_each(|(old, &new)| { - *old = new.max(*old); - }); - Some(shape) - }) - .unwrap_or(vec![]); - shape.iter_mut().for_each(|max_idx| *max_idx += 1); + + for (name, group) in &parsed_variables?.iter().chunk_by(|(name, _, _)| name) { + // Find maximum shape and check if this is a complex variable + let (shape, is_complex) = determine_variable_shape(group) + .context(format!("Error while parsing stan variable {name}"))?; + + // Calculate total size of this variable let size = shape.iter().product(); - let end_idx = start_idx + size; - variables.push(Parameter { - name: name.to_string(), - shape, - size, - start_idx, - end_idx, - }); + let mut end_idx = start_idx + size; + + // Create Parameter objects (one for real and one for imag if complex) + if is_complex { + variables.push(Parameter { + name: format!("{name}.real"), + shape: shape.clone(), + size, + start_idx, + end_idx, + }); + start_idx = end_idx; + end_idx = start_idx + size; + variables.push(Parameter { + name: format!("{name}.imag"), + shape, + size, + start_idx, + end_idx, + }); + } else { + variables.push(Parameter { + name: name.to_string(), + shape, + size, + start_idx, + end_idx, + }); + } + + // Move to the next variable start_idx = end_idx; } + Ok(variables) } +// Helper function to determine the shape and complex flag for a group of variables +fn determine_variable_shape<'a, I>(group: I) -> anyhow::Result<(Vec, bool)> +where + I: Iterator)>, +{ + let group = group.collect_vec(); + + let (mut shape, is_complex) = group + .iter() + .map(|&(_, is_complex, ref idx)| (idx, is_complex)) + .fold(None, |acc, (elem_index, &elem_is_complex)| { + let (mut shape, is_complex) = acc.unwrap_or((elem_index.clone(), elem_is_complex)); + assert!( + is_complex == elem_is_complex, + "Inconsistent complex flags for same variable" + ); + + // Find maximum index in each dimension + shape + .iter_mut() + .zip_eq(elem_index.iter()) + .for_each(|(old, &new)| { + *old = new.max(*old); + }); + + Some((shape, is_complex)) + }) + .expect("List of variable entries cannot be empty"); + + shape.iter_mut().for_each(|max_idx| *max_idx += 1); + + // Check if the indices are in Fortran order + let mut expected_index: Vec = vec![0; shape.len()]; + let mut expect_imag = false; + for (_, _, idx) in group.iter() { + if idx != &expected_index { + bail!("Stan returned data that was not in the expected order.") + } + if is_complex { + expect_imag = !expect_imag; + } + if !expect_imag { + // increment expected index + for i in 0..shape.len() { + if expected_index[i] < shape[i] - 1 { + expected_index[i] += 1; + break; + } else { + expected_index[i] = 0; + } + } + } + } + + Ok((shape, is_complex)) +} #[pymethods] impl StanModel { #[new] @@ -155,7 +252,9 @@ impl StanModel { let model = Arc::new( bridgestan::Model::new(lib.0, data.as_ref(), seed).map_err(anyhow::Error::new)?, ); - let variables = params(&model, true, true)?; + + let var_string = model.param_names(true, true); + let variables = params(var_string)?; let transform_adapter = transform_adapter.map(PyTransformAdapt::new); Ok(StanModel { model, @@ -531,8 +630,8 @@ impl Model for StanModel { fn new_trace<'a, S: Settings, R: rand::Rng + ?Sized>( &'a self, - _rng: &mut R, - chain: u64, + rng: &mut R, + _chain: u64, settings: &S, ) -> anyhow::Result> { let draws = settings.hint_num_tune() + settings.hint_num_draws(); @@ -541,7 +640,8 @@ impl Model for StanModel { .iter() .map(|var| Vec::with_capacity(var.size * draws)) .collect(); - let rng = self.model.new_rng(chain as u32)?; + let seed = rng.next_u32(); + let rng = self.model.new_rng(seed)?; let buffer = vec![0f64; self.model.param_num(true, true)]; Ok(StanTrace { model: self, @@ -555,7 +655,7 @@ impl Model for StanModel { fn math(&self) -> anyhow::Result> { Ok(CpuMath::new(StanDensity { inner: &self.model, - transform_adapter: self.transform_adapter.as_ref().map(|v| v.clone()), + transform_adapter: self.transform_adapter.clone(), })) } @@ -605,7 +705,6 @@ mod tests { 0., 6., 12., 18., 24., 2., 8., 14., 20., 26., 4., 10., 16., 22., 28., 1., 7., 13., 19., 25., 3., 9., 15., 21., 27., 5., 11., 17., 23., 29., ]; - dbg!(&out); assert!(expect.iter().zip_eq(out.iter()).all(|(a, b)| a == b)); let data = vec![ @@ -618,7 +717,6 @@ mod tests { 0., 6., 12., 18., 24., 2., 8., 14., 20., 26., 4., 10., 16., 22., 28., 1., 7., 13., 19., 25., 3., 9., 15., 21., 27., 5., 11., 17., 23., 29., ]; - dbg!(&out); assert!(expect.iter().zip_eq(out.iter()).all(|(a, b)| a == b)); let data = vec![ @@ -631,7 +729,358 @@ mod tests { 0., 15., 5., 20., 10., 25., 1., 16., 6., 21., 11., 26., 2., 17., 7., 22., 12., 27., 3., 18., 8., 23., 13., 28., 4., 19., 9., 24., 14., 29., ]; - dbg!(&out); assert!(expect.iter().zip_eq(out.iter()).all(|(a, b)| a == b)); } + + #[test] + fn parse_vars() { + let vars = "x.1.1,x.2.1,x.3.1,x.1.2,x.2.2,x.3.2"; + let parsed = super::params(vars).unwrap(); + assert!(parsed.len() == 1); + let parsed = parsed[0].clone(); + assert!(parsed.name == "x"); + assert!(parsed.shape == vec![3, 2]); + + // Incorrect order + let vars = "x.1.2,x.1.1,x.2.1,x.2.2,x.3.1,x.3.2"; + assert!(super::params(vars).is_err()); + + // Incorrect order + let vars = "x.1.2.real,x.1.2.imag"; + assert!(super::params(vars).is_err()); + + let vars = "x.1.1.real,x.1.1.imag,x.2.1.real,x.2.1.imag,x.3.1.real,x.3.1.imag"; + let parsed = super::params(vars).unwrap(); + assert!(parsed.len() == 2); + let var = parsed[0].clone(); + assert!(var.name == "x.real"); + assert!(var.shape == vec![3, 1]); + + let var = parsed[1].clone(); + assert!(var.name == "x.imag"); + assert!(var.shape == vec![3, 1]); + + // Test single variable + let vars = "alpha"; + let parsed = super::params(vars).unwrap(); + assert_eq!(parsed.len(), 1); + let var = &parsed[0]; + assert_eq!(var.name, "alpha"); + assert_eq!(var.shape, Vec::::new()); + assert_eq!(var.size, 1); + + // Test multiple scalar variables + let vars = "alpha,beta,gamma"; + let parsed = super::params(vars).unwrap(); + assert_eq!(parsed.len(), 3); + assert_eq!(parsed[0].name, "alpha"); + assert_eq!(parsed[1].name, "beta"); + assert_eq!(parsed[2].name, "gamma"); + + // Test 1D array + let vars = "theta.1,theta.2,theta.3,theta.4"; + let parsed = super::params(vars).unwrap(); + assert_eq!(parsed.len(), 1); + let var = &parsed[0]; + assert_eq!(var.name, "theta"); + assert_eq!(var.shape, vec![4]); + assert_eq!(var.size, 4); + + // Test variable name with colons and dots + let vars = "x:1:2.4:1.1,x:1:2.4:1.2,x:1:2.4:1.3"; + let parsed = super::params(vars).unwrap(); + assert_eq!(parsed.len(), 1); + let var = &parsed[0]; + assert_eq!(var.name, "x:1:2.4:1"); + assert_eq!(var.shape, vec![3]); + assert_eq!(var.size, 3); + + let vars = " + a, + base, + base_i, + pair:1, + pair:2, + nested:1, + nested:2:1, + nested:2:2.real, + nested:2:2.imag, + arr_pair.1:1, + arr_pair.1:2, + arr_pair.2:1, + arr_pair.2:2, + arr_very_nested.1:1:1, + arr_very_nested.1:1:2:1, + arr_very_nested.1:1:2:2.real, + arr_very_nested.1:1:2:2.imag, + arr_very_nested.1:2, + arr_very_nested.2:1:1, + arr_very_nested.2:1:2:1, + arr_very_nested.2:1:2:2.real, + arr_very_nested.2:1:2:2.imag, + arr_very_nested.2:2, + arr_very_nested.3:1:1, + arr_very_nested.3:1:2:1, + arr_very_nested.3:1:2:2.real, + arr_very_nested.3:1:2:2.imag, + arr_very_nested.3:2, + arr_2d_pair.1.1:1, + arr_2d_pair.1.1:2, + arr_2d_pair.2.1:1, + arr_2d_pair.2.1:2, + arr_2d_pair.3.1:1, + arr_2d_pair.3.1:2, + arr_2d_pair.1.2:1, + arr_2d_pair.1.2:2, + arr_2d_pair.2.2:1, + arr_2d_pair.2.2:2, + arr_2d_pair.3.2:1, + arr_2d_pair.3.2:2, + basep1, + basep2, + basep3, + basep4, + basep5, + ultimate.1.1:1.1:1, + ultimate.1.1:1.1:2.1, + ultimate.1.1:1.1:2.2, + ultimate.1.1:1.2:1, + ultimate.1.1:1.2:2.1, + ultimate.1.1:1.2:2.2, + ultimate.1.1:2.1.1, + ultimate.1.1:2.2.1, + ultimate.1.1:2.3.1, + ultimate.1.1:2.4.1, + ultimate.1.1:2.1.2, + ultimate.1.1:2.2.2, + ultimate.1.1:2.3.2, + ultimate.1.1:2.4.2, + ultimate.1.1:2.1.3, + ultimate.1.1:2.2.3, + ultimate.1.1:2.3.3, + ultimate.1.1:2.4.3, + ultimate.1.1:2.1.4, + ultimate.1.1:2.2.4, + ultimate.1.1:2.3.4, + ultimate.1.1:2.4.4, + ultimate.1.1:2.1.5, + ultimate.1.1:2.2.5, + ultimate.1.1:2.3.5, + ultimate.1.1:2.4.5, + ultimate.2.1:1.1:1, + ultimate.2.1:1.1:2.1, + ultimate.2.1:1.1:2.2, + ultimate.2.1:1.2:1, + ultimate.2.1:1.2:2.1, + ultimate.2.1:1.2:2.2, + ultimate.2.1:2.1.1, + ultimate.2.1:2.2.1, + ultimate.2.1:2.3.1, + ultimate.2.1:2.4.1, + ultimate.2.1:2.1.2, + ultimate.2.1:2.2.2, + ultimate.2.1:2.3.2, + ultimate.2.1:2.4.2, + ultimate.2.1:2.1.3, + ultimate.2.1:2.2.3, + ultimate.2.1:2.3.3, + ultimate.2.1:2.4.3, + ultimate.2.1:2.1.4, + ultimate.2.1:2.2.4, + ultimate.2.1:2.3.4, + ultimate.2.1:2.4.4, + ultimate.2.1:2.1.5, + ultimate.2.1:2.2.5, + ultimate.2.1:2.3.5, + ultimate.2.1:2.4.5, + ultimate.1.2:1.1:1, + ultimate.1.2:1.1:2.1, + ultimate.1.2:1.1:2.2, + ultimate.1.2:1.2:1, + ultimate.1.2:1.2:2.1, + ultimate.1.2:1.2:2.2, + ultimate.1.2:2.1.1, + ultimate.1.2:2.2.1, + ultimate.1.2:2.3.1, + ultimate.1.2:2.4.1, + ultimate.1.2:2.1.2, + ultimate.1.2:2.2.2, + ultimate.1.2:2.3.2, + ultimate.1.2:2.4.2, + ultimate.1.2:2.1.3, + ultimate.1.2:2.2.3, + ultimate.1.2:2.3.3, + ultimate.1.2:2.4.3, + ultimate.1.2:2.1.4, + ultimate.1.2:2.2.4, + ultimate.1.2:2.3.4, + ultimate.1.2:2.4.4, + ultimate.1.2:2.1.5, + ultimate.1.2:2.2.5, + ultimate.1.2:2.3.5, + ultimate.1.2:2.4.5, + ultimate.2.2:1.1:1, + ultimate.2.2:1.1:2.1, + ultimate.2.2:1.1:2.2, + ultimate.2.2:1.2:1, + ultimate.2.2:1.2:2.1, + ultimate.2.2:1.2:2.2, + ultimate.2.2:2.1.1, + ultimate.2.2:2.2.1, + ultimate.2.2:2.3.1, + ultimate.2.2:2.4.1, + ultimate.2.2:2.1.2, + ultimate.2.2:2.2.2, + ultimate.2.2:2.3.2, + ultimate.2.2:2.4.2, + ultimate.2.2:2.1.3, + ultimate.2.2:2.2.3, + ultimate.2.2:2.3.3, + ultimate.2.2:2.4.3, + ultimate.2.2:2.1.4, + ultimate.2.2:2.2.4, + ultimate.2.2:2.3.4, + ultimate.2.2:2.4.4, + ultimate.2.2:2.1.5, + ultimate.2.2:2.2.5, + ultimate.2.2:2.3.5, + ultimate.2.2:2.4.5, + ultimate.1.3:1.1:1, + ultimate.1.3:1.1:2.1, + ultimate.1.3:1.1:2.2, + ultimate.1.3:1.2:1, + ultimate.1.3:1.2:2.1, + ultimate.1.3:1.2:2.2, + ultimate.1.3:2.1.1, + ultimate.1.3:2.2.1, + ultimate.1.3:2.3.1, + ultimate.1.3:2.4.1, + ultimate.1.3:2.1.2, + ultimate.1.3:2.2.2, + ultimate.1.3:2.3.2, + ultimate.1.3:2.4.2, + ultimate.1.3:2.1.3, + ultimate.1.3:2.2.3, + ultimate.1.3:2.3.3, + ultimate.1.3:2.4.3, + ultimate.1.3:2.1.4, + ultimate.1.3:2.2.4, + ultimate.1.3:2.3.4, + ultimate.1.3:2.4.4, + ultimate.1.3:2.1.5, + ultimate.1.3:2.2.5, + ultimate.1.3:2.3.5, + ultimate.1.3:2.4.5, + ultimate.2.3:1.1:1, + ultimate.2.3:1.1:2.1, + ultimate.2.3:1.1:2.2, + ultimate.2.3:1.2:1, + ultimate.2.3:1.2:2.1, + ultimate.2.3:1.2:2.2, + ultimate.2.3:2.1.1, + ultimate.2.3:2.2.1, + ultimate.2.3:2.3.1, + ultimate.2.3:2.4.1, + ultimate.2.3:2.1.2, + ultimate.2.3:2.2.2, + ultimate.2.3:2.3.2, + ultimate.2.3:2.4.2, + ultimate.2.3:2.1.3, + ultimate.2.3:2.2.3, + ultimate.2.3:2.3.3, + ultimate.2.3:2.4.3, + ultimate.2.3:2.1.4, + ultimate.2.3:2.2.4, + ultimate.2.3:2.3.4, + ultimate.2.3:2.4.4, + ultimate.2.3:2.1.5, + ultimate.2.3:2.2.5, + ultimate.2.3:2.3.5, + ultimate.2.3:2.4.5 + "; + let parsed = super::params(vars).unwrap(); + assert_eq!(parsed[0].name, "a"); + assert_eq!(parsed[0].shape, vec![0usize; 0]); + + assert_eq!(parsed[1].name, "base"); + assert_eq!(parsed[1].shape, vec![0usize; 0]); + + assert_eq!(parsed[2].name, "base_i"); + assert_eq!(parsed[2].shape, vec![0usize; 0]); + + assert_eq!(parsed[3].name, "pair:1"); + assert_eq!(parsed[3].shape, vec![0usize; 0]); + + assert_eq!(parsed[4].name, "pair:2"); + assert_eq!(parsed[4].shape, vec![0usize; 0]); + + assert_eq!(parsed[5].name, "nested:1"); + assert_eq!(parsed[5].shape, vec![0usize; 0]); + + assert_eq!(parsed[6].name, "nested:2:1"); + assert_eq!(parsed[6].shape, vec![0usize; 0]); + + assert_eq!(parsed[7].name, "nested:2:2.real"); + assert_eq!(parsed[7].shape, vec![0usize; 0]); + + assert_eq!(parsed[8].name, "nested:2:2.imag"); + assert_eq!(parsed[8].shape, vec![0usize; 0]); + + assert_eq!(parsed[9].name, "arr_pair.1:1"); + assert_eq!(parsed[9].shape, vec![0usize; 0]); + + assert_eq!(parsed[10].name, "arr_pair.1:2"); + assert_eq!(parsed[10].shape, vec![0usize; 0]); + + assert_eq!(parsed[11].name, "arr_pair.2:1"); + assert_eq!(parsed[11].shape, vec![0usize; 0]); + + assert_eq!(parsed[12].name, "arr_pair.2:2"); + assert_eq!(parsed[12].shape, vec![0usize; 0]); + + assert_eq!(parsed[13].name, "arr_very_nested.1:1:1"); + assert_eq!(parsed[13].shape, vec![0usize; 0]); + + assert_eq!(parsed[14].name, "arr_very_nested.1:1:2:1"); + assert_eq!(parsed[14].shape, vec![0usize; 0]); + + assert_eq!(parsed[15].name, "arr_very_nested.1:1:2:2.real"); + assert_eq!(parsed[15].shape, vec![0usize; 0]); + + assert_eq!(parsed[16].name, "arr_very_nested.1:1:2:2.imag"); + assert_eq!(parsed[16].shape, vec![0usize; 0]); + + assert_eq!(parsed[17].name, "arr_very_nested.1:2"); + assert_eq!(parsed[17].shape, vec![0usize; 0]); + + assert_eq!(parsed[18].name, "arr_very_nested.2:1:1"); + assert_eq!(parsed[18].shape, vec![0usize; 0]); + + assert_eq!(parsed[19].name, "arr_very_nested.2:1:2:1"); + assert_eq!(parsed[19].shape, vec![0usize; 0]); + + assert_eq!(parsed[20].name, "arr_very_nested.2:1:2:2.real"); + assert_eq!(parsed[20].shape, vec![0usize; 0]); + + assert_eq!(parsed[21].name, "arr_very_nested.2:1:2:2.imag"); + assert_eq!(parsed[21].shape, vec![0usize; 0]); + + assert_eq!(parsed[22].name, "arr_very_nested.2:2"); + assert_eq!(parsed[22].shape, vec![0usize; 0]); + + assert_eq!(parsed[23].name, "arr_very_nested.3:1:1"); + assert_eq!(parsed[23].shape, vec![0usize; 0]); + + assert_eq!(parsed[24].name, "arr_very_nested.3:1:2:1"); + assert_eq!(parsed[24].shape, vec![0usize; 0]); + + assert_eq!(parsed[25].name, "arr_very_nested.3:1:2:2.real"); + assert_eq!(parsed[25].shape, vec![0usize; 0]); + + assert_eq!(parsed[26].name, "arr_very_nested.3:1:2:2.imag"); + assert_eq!(parsed[26].shape, vec![0usize; 0]); + + assert_eq!(parsed[27].name, "arr_very_nested.3:2"); + assert_eq!(parsed[27].shape, vec![0usize; 0]); + } } diff --git a/tests/test_stan.py b/tests/test_stan.py index 89201c5..f44b755 100644 --- a/tests/test_stan.py +++ b/tests/test_stan.py @@ -27,6 +27,127 @@ def test_stan_model(): trace.posterior.a # noqa: B018 +@pytest.mark.stan +def test_seed(): + model = """ + data {} + parameters { + real a; + } + model { + a ~ normal(0, 1); + } + generated quantities { + real b = normal_rng(0, 1); + } + """ + + compiled_model = nutpie.compile_stan_model(code=model) + trace = nutpie.sample(compiled_model, seed=42) + trace2 = nutpie.sample(compiled_model, seed=42) + trace3 = nutpie.sample(compiled_model, seed=43) + + assert np.allclose(trace.posterior.a, trace2.posterior.a) + assert np.allclose(trace.posterior.b, trace2.posterior.b) + + assert not np.allclose(trace.posterior.a, trace3.posterior.a) + assert not np.allclose(trace.posterior.b, trace3.posterior.b) + # Check that all chains are pairwise different + for i in range(len(trace.posterior.a)): + for j in range(i + 1, len(trace.posterior.a)): + assert not np.allclose(trace.posterior.a[i], trace.posterior.a[j]) + assert not np.allclose(trace.posterior.b[i], trace.posterior.b[j]) + # Check that all chains are pairwise different between seeds + for i in range(len(trace.posterior.a)): + for j in range(len(trace3.posterior.a)): + assert not np.allclose(trace.posterior.a[i], trace3.posterior.a[j]) + assert not np.allclose(trace.posterior.b[i], trace3.posterior.b[j]) + + +@pytest.mark.stan +def test_nested(): + # Adapted from + # https://github.com/stan-dev/stanio/blob/main/test/data/tuples/output.stan + model = """ + parameters { + real a; + } + model { + a ~ normal(0, 1); + } + generated quantities { + real base = normal_rng(0, 1); + int base_i = to_int(normal_rng(10, 10)); + + tuple(real, real) pair = (base, base * 2); + + tuple(real, tuple(int, complex)) nested = (base * 3, (base_i, base * 4.0i)); + array[2] tuple(real, real) arr_pair = {pair, (base * 5, base * 6)}; + + array[3] tuple(tuple(real, tuple(int, complex)), real) arr_very_nested + = {(nested, base*7), ((base*8, (base_i*2, base*9.0i)), base * 10), (nested, base*11)}; + + array[3,2] tuple(real, real) arr_2d_pair = {{(base * 12, base * 13), (base * 14, base * 15)}, + {(base * 16, base * 17), (base * 18, base * 19)}, + {(base * 20, base * 21), (base * 22, base * 23)}}; + + real basep1 = base + 1, basep2 = base + 2; + real basep3 = base + 3, basep4 = base + 4, basep5 = base + 5; + array[2,3] tuple(array[2] tuple(real, vector[2]), matrix[4,5]) ultimate = + { + {( + {(base, [base *2, base *3]'), (base *4, [base*5, base*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * base + ), + ( + {(basep1, [basep1 *2, basep1 *3]'), (basep1 *4, [basep1*5, basep1*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep1 + ), + ( + {(basep2, [basep2 *2, basep2 *3]'), (basep2 *4, [basep2*5, basep2*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep2 + ) + }, + {( + {(basep3, [basep3 *2, basep3 *3]'), (basep3 *4, [basep3*5, basep3*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep3 + ), + ( + {(basep4, [basep4 *2, basep4 *3]'), (basep4 *4, [basep4*5, basep4*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep4 + ), + ( + {(basep5, [basep5 *2, basep5 *3]'), (basep5 *4, [basep5*5, basep5*6]')}, + to_matrix(linspaced_vector(20, 7, 11), 4, 5) * basep5 + ) + }}; + } + """ + + compiled = nutpie.compile_stan_model(code=model) + tr = nutpie.sample(compiled, chains=6) + base = tr.posterior.base + + assert np.allclose(tr.posterior["nested:2:2.imag"], 4 * base) + assert np.allclose(tr.posterior["nested:2:2.real"], 0.0) + + assert np.allclose(tr.posterior["ultimate.1.1:1.1:1"], base) + assert np.allclose(tr.posterior["ultimate.1.2:1.1:1"], base + 1) + assert np.allclose(tr.posterior["ultimate.1.3:1.1:1"], base + 2) + assert np.allclose(tr.posterior["ultimate.2.1:1.1:1"], base + 3) + assert np.allclose(tr.posterior["ultimate.2.2:1.1:1"], base + 4) + assert np.allclose(tr.posterior["ultimate.2.3:1.1:1"], base + 5) + + assert tr.posterior["ultimate.2.1:1.1:2"].shape == (6, 1000, 2) + assert np.allclose( + tr.posterior["ultimate.2.3:1.1:2"].values[:, :, 0], 2 * (base + 5) + ) + assert np.allclose( + tr.posterior["ultimate.2.3:1.1:2"].values[:, :, 1], 3 * (base + 5) + ) + assert np.allclose(tr.posterior["base_i"], tr.posterior.base_i.astype(int)) + + @pytest.mark.stan def test_stan_model_data(): model = """