Skip to content

Commit fcf4f7b

Browse files
committed
fix: allow variables with zero shapes
1 parent 59a5ba7 commit fcf4f7b

File tree

4 files changed

+61
-5
lines changed

4 files changed

+61
-5
lines changed

src/pymc.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use pyo3::{
1515
Bound, Py, PyAny, PyObject, PyResult, Python,
1616
};
1717

18+
use rand_distr::num_traits::CheckedEuclid;
1819
use thiserror::Error;
1920

2021
type UserData = *const std::ffi::c_void;
@@ -128,8 +129,8 @@ impl CpuLogpFunc for &LogpFunc {
128129
let retcode = unsafe {
129130
(self.func)(
130131
self.dim,
131-
&position[0] as *const f64,
132-
&mut gradient[0] as *mut f64,
132+
position.as_ptr(),
133+
gradient.as_mut_ptr(),
133134
logp_ptr,
134135
self.user_data_ptr,
135136
)
@@ -148,6 +149,7 @@ pub(crate) struct PyMcTrace<'model> {
148149
var_sizes: Vec<usize>,
149150
var_names: Vec<String>,
150151
expand: &'model ExpandFunc,
152+
count: usize,
151153
}
152154

153155
impl<'model> DrawStorage for PyMcTrace<'model> {
@@ -165,14 +167,20 @@ impl<'model> DrawStorage for PyMcTrace<'model> {
165167
data.extend_from_slice(vals);
166168
start = end;
167169
}
170+
self.count += 1;
171+
168172
Ok(())
169173
}
170174

171175
fn finalize(self) -> Result<Arc<dyn Array>> {
172176
let (fields, arrays): (Vec<_>, _) = izip!(self.data, self.var_names, self.var_sizes)
173177
.map(|(data, name, size)| {
174-
assert!(data.len() % size == 0);
175-
let num_arrays = data.len() / size;
178+
let (num_arrays, rem) = data
179+
.len()
180+
.checked_div_rem_euclid(&size)
181+
.unwrap_or((self.count, 0));
182+
assert!(rem == 0);
183+
assert!(num_arrays == self.count);
176184
let data = Float64Array::from(data);
177185
let item_field = Arc::new(Field::new("item", DataType::Float64, false));
178186
let offsets = OffsetBuffer::from_lengths((0..num_arrays).map(|_| size));
@@ -206,6 +214,7 @@ impl<'model> PyMcTrace<'model> {
206214
var_sizes: model.var_sizes.clone(),
207215
var_names: model.var_names.clone(),
208216
expand: &model.expand,
217+
count: 0,
209218
}
210219
}
211220

src/stan.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ pub struct StanModel {
8686

8787
/// Return meta information about the constrained parameters of the model
8888
fn params(var_string: &str) -> anyhow::Result<Vec<Parameter>> {
89+
if var_string.is_empty() {
90+
return Ok(vec![]);
91+
}
8992
// Parse each variable string into (name, is_complex, indices)
9093
let parsed_variables: anyhow::Result<Vec<(String, bool, Vec<usize>)>> = var_string
9194
.split(',')
@@ -540,6 +543,7 @@ pub struct StanTrace<'model> {
540543
trace: Vec<Vec<f64>>,
541544
expanded_buffer: Box<[f64]>,
542545
rng: bridgestan::Rng<&'model bridgestan::StanLibrary>,
546+
count: usize,
543547
}
544548

545549
impl<'model> Clone for StanTrace<'model> {
@@ -559,6 +563,7 @@ impl<'model> Clone for StanTrace<'model> {
559563
trace: self.trace.clone(),
560564
expanded_buffer: self.expanded_buffer.clone(),
561565
rng,
566+
count: self.count,
562567
}
563568
}
564569
}
@@ -591,6 +596,7 @@ impl<'model> DrawStorage for StanTrace<'model> {
591596
// We need to transpose
592597
fortran_to_c_order(slice, &var.shape, trace);
593598
}
599+
self.count += 1;
594600
Ok(())
595601
}
596602

@@ -613,7 +619,7 @@ impl<'model> DrawStorage for StanTrace<'model> {
613619
.unzip();
614620

615621
Ok(Arc::new(
616-
StructArray::try_new(fields.into(), arrays, None)
622+
StructArray::try_new_with_length(fields.into(), arrays, None, self.count)
617623
.context("Could not create arrow StructArray")?,
618624
))
619625
}
@@ -649,6 +655,7 @@ impl Model for StanModel {
649655
trace,
650656
rng,
651657
expanded_buffer: buffer.into(),
658+
count: 0,
652659
})
653660
}
654661

@@ -734,6 +741,10 @@ mod tests {
734741

735742
#[test]
736743
fn parse_vars() {
744+
let vars = "";
745+
let parsed = super::params(vars).unwrap();
746+
assert!(parsed.len() == 0);
747+
737748
let vars = "x.1.1,x.2.1,x.3.1,x.1.2,x.2.2,x.3.2";
738749
let parsed = super::params(vars).unwrap();
739750
assert!(parsed.len() == 1);

tests/test_pymc.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,23 @@ def test_pymc_model(backend, gradient_backend):
3131
trace.posterior.a # noqa: B018
3232

3333

34+
@pytest.mark.pymc
35+
@parameterize_backends
36+
def test_zero_size(backend, gradient_backend):
37+
import pytensor.tensor as pt
38+
39+
with pm.Model() as model:
40+
a = pm.Normal("a", shape=(0, 0, 10))
41+
pm.Deterministic("b", pt.exp(a))
42+
43+
compiled = nutpie.compile_pymc_model(
44+
model, backend=backend, gradient_backend=gradient_backend
45+
)
46+
trace = nutpie.sample(compiled, chains=1, draws=17, tune=100)
47+
assert trace.posterior.a.shape == (1, 17, 0, 0, 10)
48+
assert trace.posterior.b.shape == (1, 17, 0, 0, 10)
49+
50+
3451
@pytest.mark.pymc
3552
@parameterize_backends
3653
def test_pymc_model_float32(backend, gradient_backend):

tests/test_stan.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,25 @@ def test_stan_model():
2727
trace.posterior.a # noqa: B018
2828

2929

30+
@pytest.mark.stan
31+
def test_empty():
32+
model = """
33+
data {}
34+
parameters {
35+
array[0] real a;
36+
}
37+
model {
38+
a ~ normal(0, 1);
39+
}
40+
"""
41+
42+
compiled_model = nutpie.compile_stan_model(code=model)
43+
trace = nutpie.sample(compiled_model) # noqa: F841
44+
# TODO: Variable `a` is missing because of this bridgestan issue:
45+
# https://github.com/roualdes/bridgestan/issues/278
46+
# assert trace.posterior.a.shape == (0, 1000)
47+
48+
3049
@pytest.mark.stan
3150
def test_seed():
3251
model = """

0 commit comments

Comments
 (0)