Skip to content

Commit 34d151a

Browse files
aseyboldttwiecki
authored andcommitted
fix: Fix random variables with missing values in pymc deterministics
1 parent 150dcee commit 34d151a

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

python/nutpie/compile_pymc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
202202
raise ValueError(f"Shared variables must have unique names: {val.name}")
203203
shared_data[val.name] = val.get_value().copy()
204204
shared_vars[val.name] = val
205+
seen.add(val)
205206

206207
for val in shared_data.values():
207208
val.flags.writeable = False

tests/test_pymc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def test_pymc_model_shared():
9999
def test_missing():
100100
with pm.Model(coords={"obs": range(4)}) as model:
101101
mu = pm.Normal("mu")
102-
pm.Normal("y", mu, observed=[0, -1, 1, np.nan], dims="obs")
102+
y = pm.Normal("y", mu, observed=[0, -1, 1, np.nan], dims="obs")
103+
pm.Deterministic("y2", 2 * y, dims="obs")
103104

104105
compiled = nutpie.compile_pymc_model(model)
105106
tr = nutpie.sample(compiled, chains=1, seed=1)

0 commit comments

Comments
 (0)