Skip to content

Commit a592208

Browse files
committed
Check for observed variables in the trace as well as the model
1 parent e6767ab commit a592208

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

pymc/sampling/forward.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,13 +767,15 @@ def sample_posterior_predictive(
767767
if "coords" not in idata_kwargs:
768768
idata_kwargs["coords"] = {}
769769
idata: InferenceData | None = None
770+
observed_data = None
770771
stacked_dims = None
771772
if isinstance(trace, InferenceData):
772773
_constant_data = getattr(trace, "constant_data", None)
773774
if _constant_data is not None:
774775
trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()})
775776
constant_data.update({str(k): v.data for k, v in _constant_data.items()})
776777
idata = trace
778+
observed_data = trace["observed_data"]
777779
trace = trace["posterior"]
778780
if isinstance(trace, xarray.Dataset):
779781
trace_coords.update({str(k): v.data for k, v in trace.coords.items()})
@@ -817,6 +819,8 @@ def sample_posterior_predictive(
817819
vars_ = [model[x] for x in var_names]
818820
else:
819821
vars_ = model.observed_RVs + observed_dependent_deterministics(model)
822+
if observed_data is not None:
823+
vars_ += [model[x] for x in observed_data if x in model]
820824

821825
vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))
822826

tests/sampling/test_forward.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,16 @@ def test_normal_scalar(self):
481481
chains=nchains,
482482
)
483483

484+
# test that trace is used in ppc
485+
with pm.Model() as model_ppc:
486+
mu = pm.Normal("mu", 0.0, 1.0)
487+
a = pm.Normal("a", mu=mu, sigma=1)
488+
489+
ppc = pm.sample_posterior_predictive(
490+
trace=trace, model=model_ppc, return_inferencedata=False
491+
)
492+
assert "a" in ppc
493+
484494
with model:
485495
# test list input
486496
ppc0 = pm.sample_posterior_predictive(

0 commit comments

Comments
 (0)