Skip to content

Commit 6073d8d

Browse files
committed
Fix bug in compute_deterministics
1 parent 7ffd47d commit 6073d8d

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

pymc/sampling/deterministic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def compute_deterministics(
8484

8585
if var_names is None:
8686
deterministics = model.deterministics
87+
var_names = [det.name for det in deterministics]
8788
else:
8889
deterministics = [model[var_name] for var_name in var_names]
8990
if not set(deterministics).issubset(set(model.deterministics)):
@@ -101,7 +102,7 @@ def compute_deterministics(
101102
new_dataset = apply_function_over_dataset(
102103
fn,
103104
dataset[[rv.name for rv in model.free_RVs]],
104-
output_var_names=[det.name for det in model.deterministics],
105+
output_var_names=var_names,
105106
dims=dims,
106107
coords=coords,
107108
sample_dims=sample_dims,

tests/sampling/test_deterministic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ def test_compute_deterministics():
5959
assert extended_with_mu["mu"].dims == ("chain", "draw", "group")
6060
assert_allclose(extended_with_mu["mu"], dataset["mu_raw"].cumsum("group"))
6161

62+
only_sigma = compute_deterministics(dataset, var_names=["sigma"], model=m, progressbar=False)
63+
assert set(only_sigma.data_vars.variables) == {"sigma"}
64+
assert only_sigma["sigma"].dims == ("chain", "draw")
65+
assert_allclose(only_sigma["sigma"], np.exp(dataset["sigma_raw"]))
66+
6267

6368
def test_docstring_example():
6469
import pymc as pm

0 commit comments

Comments
 (0)