diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index b506aa7ba..7024dca4a 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -825,7 +825,7 @@ def sample_posterior_predictive( if return_inferencedata and not extend_inferencedata: return InferenceData() elif return_inferencedata and extend_inferencedata: - return trace + return trace if idata is None else idata return {} vars_in_trace = get_vars_in_point_list(_trace, model) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 24579bae0..8563ff720 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -492,6 +492,11 @@ def test_normal_scalar(self): ppc = pm.sample_posterior_predictive(trace, var_names=[], return_inferencedata=False) assert len(ppc) == 0 + # test empty ppc with extend_inferencedata + assert isinstance(trace, InferenceData) + ppc = pm.sample_posterior_predictive(trace, var_names=[], extend_inferencedata=True) + assert ppc is trace + # test keep_size parameter ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False) assert ppc["a"].shape == (nchains, ndraws)