File tree Expand file tree Collapse file tree 2 files changed +14
-0
lines changed
Expand file tree Collapse file tree 2 files changed +14
-0
lines changed Original file line number Diff line number Diff line change @@ -767,13 +767,15 @@ def sample_posterior_predictive(
767767 if "coords" not in idata_kwargs :
768768 idata_kwargs ["coords" ] = {}
769769 idata : InferenceData | None = None
770+ observed_data = None
770771 stacked_dims = None
771772 if isinstance (trace , InferenceData ):
772773 _constant_data = getattr (trace , "constant_data" , None )
773774 if _constant_data is not None :
774775 trace_coords .update ({str (k ): v .data for k , v in _constant_data .coords .items ()})
775776 constant_data .update ({str (k ): v .data for k , v in _constant_data .items ()})
776777 idata = trace
778+ observed_data = trace ["observed_data" ]
777779 trace = trace ["posterior" ]
778780 if isinstance (trace , xarray .Dataset ):
779781 trace_coords .update ({str (k ): v .data for k , v in trace .coords .items ()})
@@ -817,6 +819,8 @@ def sample_posterior_predictive(
817819 vars_ = [model [x ] for x in var_names ]
818820 else :
819821 vars_ = model .observed_RVs + observed_dependent_deterministics (model )
822+ if observed_data is not None :
823+ vars_ += [model [x ] for x in observed_data if x in model ]
820824
821825 vars_to_sample = list (get_default_varnames (vars_ , include_transformed = False ))
822826
Original file line number Diff line number Diff line change @@ -481,6 +481,16 @@ def test_normal_scalar(self):
481481 chains = nchains ,
482482 )
483483
484+ # test that trace is used in ppc
485+ with pm .Model () as model_ppc :
486+ mu = pm .Normal ("mu" , 0.0 , 1.0 )
487+ a = pm .Normal ("a" , mu = mu , sigma = 1 )
488+
489+ ppc = pm .sample_posterior_predictive (
490+ trace = trace , model = model_ppc , return_inferencedata = False
491+ )
492+ assert "a" in ppc
493+
484494 with model :
485495 # test list input
486496 ppc0 = pm .sample_posterior_predictive (
You can’t perform that action at this time.
0 commit comments