File tree Expand file tree Collapse file tree 2 files changed +8
-5
lines changed
Expand file tree Collapse file tree 2 files changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -345,10 +345,13 @@ def draw(
345345 return [np .stack (v ) for v in drawn_values ]
346346
347347
348- def observed_dependent_deterministics (model : Model ):
348+ def observed_dependent_deterministics (model : Model , extra_observeds = None ):
349349 """Find deterministics that depend directly on observed variables."""
350+ if extra_observeds is None :
351+ extra_observeds = []
352+
350353 deterministics = model .deterministics
351- observed_rvs = set (model .observed_RVs )
354+ observed_rvs = set (model .observed_RVs + extra_observeds )
352355 blockers = model .basic_RVs
353356 return [
354357 deterministic
@@ -821,6 +824,7 @@ def sample_posterior_predictive(
821824 vars_ = model .observed_RVs + observed_dependent_deterministics (model )
822825 if observed_data is not None :
823826 vars_ += [model [x ] for x in observed_data if x in model and x not in vars_ ]
827+ vars_ += observed_dependent_deterministics (model , vars_ )
824828
825829 vars_to_sample = list (get_default_varnames (vars_ , include_transformed = False ))
826830
Original file line number Diff line number Diff line change @@ -561,7 +561,6 @@ def test_external_trace(self):
561561 )
562562 assert list (ppc .keys ()) == ["a" ]
563563
564- @pytest .mark .xfail (reason = "Auto-imputation of variables not supported in this setting" )
565564 def test_external_trace_det (self ):
566565 nchains = 2
567566 ndraws = 500
@@ -578,12 +577,12 @@ def test_external_trace_det(self):
578577 with pm .Model () as model_ppc :
579578 mu = pm .Normal ("mu" , 0.0 , 1.0 )
580579 a = pm .Normal ("a" , mu = mu , sigma = 1 )
581- b = pm .Deterministic ("b " , a + 1 )
580+ c = pm .Deterministic ("c " , a + 1 )
582581
583582 ppc = pm .sample_posterior_predictive (
584583 trace = trace , model = model_ppc , return_inferencedata = False
585584 )
586- assert list (ppc .keys ()) == ["a" , "b " ]
585+ assert list (ppc .keys ()) == ["a" , "c " ]
587586
588587 def test_normal_vector (self ):
589588 with pm .Model () as model :
You can’t perform that action at this time.
0 commit comments