Skip to content

Commit 73374f5

Browse files
committed
fix(pymc): Allow dims with only length specified
1 parent 2938d5a commit 73374f5

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

python/nutpie/compile_pymc.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,12 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
207207
)
208208
expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)
209209

210-
coords = {name: pd.Index(vals) for name, vals in model.coords.items()}
210+
coords = {}
211+
for name, vals in model.coords.items():
212+
if vals is None:
213+
vals = pd.RangeIndex(int(model.dim_lengths[name].eval()))
214+
coords[name] = pd.Index(vals)
215+
211216
if "unconstrained_parameter" in coords:
212217
raise ValueError("Model contains invalid name 'unconstrained_parameter'.")
213218

tests/test_pymc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@ def test_pymc_model():
1515
trace.posterior.a
1616

1717

18+
def test_pymc_model():
19+
with pm.Model() as model:
20+
model.add_coord("foo", length=5)
21+
pm.Normal("a", dims="foo")
22+
23+
compiled = nutpie.compile_pymc_model(model)
24+
trace = nutpie.sample(compiled, chains=1)
25+
trace.posterior.a
26+
27+
1828
def test_trafo():
1929
with pm.Model() as model:
2030
pm.Uniform("a")

0 commit comments

Comments
 (0)