@@ -481,16 +481,6 @@ def test_normal_scalar(self):
481481 chains = nchains ,
482482 )
483483
484- # test that trace is used in ppc
485- with pm .Model () as model_ppc :
486- mu = pm .Normal ("mu" , 0.0 , 1.0 )
487- a = pm .Normal ("a" , mu = mu , sigma = 1 )
488-
489- ppc = pm .sample_posterior_predictive (
490- trace = trace , model = model_ppc , return_inferencedata = False
491- )
492- assert "a" in ppc
493-
494484 with model :
495485 # test list input
496486 ppc0 = pm .sample_posterior_predictive (
@@ -550,6 +540,51 @@ def test_normal_scalar_idata(self):
550540 ppc = pm .sample_posterior_predictive (idata , return_inferencedata = False )
551541 assert ppc ["a" ].shape == (nchains , ndraws )
552542
543+ def test_external_trace (self ):
544+ nchains = 2
545+ ndraws = 500
546+ with pm .Model () as model :
547+ mu = pm .Normal ("mu" , 0.0 , 1.0 )
548+ a = pm .Normal ("a" , mu = mu , sigma = 1 , observed = 0.0 )
549+ trace = pm .sample (
550+ draws = ndraws ,
551+ chains = nchains ,
552+ )
553+
554+ # test that trace is used in ppc
555+ with pm .Model () as model_ppc :
556+ mu = pm .Normal ("mu" , 0.0 , 1.0 )
557+ a = pm .Normal ("a" , mu = mu , sigma = 1 )
558+
559+ ppc = pm .sample_posterior_predictive (
560+ trace = trace , model = model_ppc , return_inferencedata = False
561+ )
562+ assert list (ppc .keys ()) == ["a" ]
563+
564+ @pytest .mark .xfail (reason = "Auto-imputation of variables not supported in this setting" )
565+ def test_external_trace_det (self ):
566+ nchains = 2
567+ ndraws = 500
568+ with pm .Model () as model :
569+ mu = pm .Normal ("mu" , 0.0 , 1.0 )
570+ a = pm .Normal ("a" , mu = mu , sigma = 1 , observed = 0.0 )
571+ b = pm .Deterministic ("b" , a + 1 )
572+ trace = pm .sample (
573+ draws = ndraws ,
574+ chains = nchains ,
575+ )
576+
577+ # test that trace is used in ppc
578+ with pm .Model () as model_ppc :
579+ mu = pm .Normal ("mu" , 0.0 , 1.0 )
580+ a = pm .Normal ("a" , mu = mu , sigma = 1 )
581+ b = pm .Deterministic ("b" , a + 1 )
582+
583+ ppc = pm .sample_posterior_predictive (
584+ trace = trace , model = model_ppc , return_inferencedata = False
585+ )
586+ assert list (ppc .keys ()) == ["a" , "b" ]
587+
553588 def test_normal_vector (self ):
554589 with pm .Model () as model :
555590 mu = pm .Normal ("mu" , 0.0 , 1.0 )
0 commit comments