Skip to content

Commit d813da1

Browse files
committed
Add logic to handle conditional nodes for observed variables
1 parent e895a5c commit d813da1

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

pymc/sampling/forward.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,13 @@ def draw(
345345
return [np.stack(v) for v in drawn_values]
346346

347347

348-
def observed_dependent_deterministics(model: Model):
348+
def observed_dependent_deterministics(model: Model, extra_observeds=None):
349349
"""Find deterministics that depend directly on observed variables."""
350+
if extra_observeds is None:
351+
extra_observeds = []
352+
350353
deterministics = model.deterministics
351-
observed_rvs = set(model.observed_RVs)
354+
observed_rvs = set(model.observed_RVs + extra_observeds)
352355
blockers = model.basic_RVs
353356
return [
354357
deterministic
@@ -821,6 +824,7 @@ def sample_posterior_predictive(
821824
vars_ = model.observed_RVs + observed_dependent_deterministics(model)
822825
if observed_data is not None:
823826
vars_ += [model[x] for x in observed_data if x in model and x not in vars_]
827+
vars_ += observed_dependent_deterministics(model, vars_)
824828

825829
vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))
826830

tests/sampling/test_forward.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,6 @@ def test_external_trace(self):
561561
)
562562
assert list(ppc.keys()) == ["a"]
563563

564-
@pytest.mark.xfail(reason="Auto-imputation of variables not supported in this setting")
565564
def test_external_trace_det(self):
566565
nchains = 2
567566
ndraws = 500
@@ -578,12 +577,12 @@ def test_external_trace_det(self):
578577
with pm.Model() as model_ppc:
579578
mu = pm.Normal("mu", 0.0, 1.0)
580579
a = pm.Normal("a", mu=mu, sigma=1)
581-
b = pm.Deterministic("b", a + 1)
580+
c = pm.Deterministic("c", a + 1)
582581

583582
ppc = pm.sample_posterior_predictive(
584583
trace=trace, model=model_ppc, return_inferencedata=False
585584
)
586-
assert list(ppc.keys()) == ["a", "b"]
585+
assert list(ppc.keys()) == ["a", "c"]
587586

588587
def test_normal_vector(self):
589588
with pm.Model() as model:

0 commit comments

Comments
 (0)