@@ -12,9 +12,17 @@ using BenchmarkTools
1212 y ~ Normal (μ, sqrt (σ))
1313end
1414
15+ function make_sampling_model_and_args (model, rng, sampler)
16+ ctx = DynamicPPL. SamplingContext (rng, sampler, model. context)
17+ spl_model = DynamicPPL. contextualize (model, ctx)
18+ fargs, _ = DynamicPPL. make_evaluate_args_and_kwargs (spl_model, varinfo)
19+ return (spl_model, fargs)
20+ end
21+
1522# Case 1: Sample from the prior.
16- rng = MersenneTwister ()
17- m = Turing. Inference. TracedModel (gdemo (1.5 , 2.0 ), SampleFromPrior (), VarInfo (), rng)
23+ spl, rng = SampleFromPrior (), MersenneTwister ()
24+ spl_model, fargs = make_sampling_model (gdemo (1.5 , 2.0 ), rng, spl)
25+ m = Turing. Inference. TracedModel (spl_model, spl, VarInfo (), fargs)
1826f = m. evaluator[1 ];
1927args = m. evaluator[2 : end ];
2028
@@ -27,7 +35,9 @@ println("Run a tape...")
2735@btime t. tf (args... )
2836
2937# Case 2: SMC sampler
30- m = Turing. Inference. TracedModel (gdemo (1.5 , 2.0 ), Sampler (SMC (50 )), VarInfo (), rng)
38+ spl, rng = SMC (50 ), MersenneTwister ()
39+ spl_model, fargs = make_sampling_model (gdemo (1.5 , 2.0 ), rng, spl)
40+ m = Turing. Inference. TracedModel (spl_model, spl, VarInfo (), fargs)
3141f = m. evaluator[1 ];
3242args = m. evaluator[2 : end ];
3343
0 commit comments