@@ -26,9 +26,8 @@ function TracedModel(
26
26
" Sampling with `$(sampler. alg) ` does not support models with keyword arguments. See issue #2007 for more details." ,
27
27
)
28
28
end
29
- return TracedModel {AbstractSampler,AbstractVarInfo,Model,Tuple} (
30
- spl_model, sampler, varinfo, (spl_model. f, args... )
31
- )
29
+ evaluator = (spl_model. f, args... )
30
+ return TracedModel (spl_model, sampler, varinfo, evaluator)
32
31
end
33
32
34
33
function AdvancedPS. advance! (
@@ -60,20 +59,10 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
60
59
return Accessors. @set trace. model. varinfo = DynamicPPL. resetlogp!! (trace. model. varinfo)
61
60
end
62
61
63
- function AdvancedPS. update_rng! (
64
- trace:: AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
65
- )
66
- # Extract the `args`.
67
- args = trace. model. ctask. args
68
- # From `args`, extract the `SamplingContext`, which contains the RNG.
69
- sampling_context = args[3 ]
70
- rng = sampling_context. rng
71
- trace. rng = rng
72
- return trace
73
- end
74
-
75
- function Libtask. TapedTask (model:: TracedModel , :: Random.AbstractRNG , args... ; kwargs... ) # RNG ?
76
- return Libtask. TapedTask (model. evaluator[1 ], model. evaluator[2 : end ]. .. ; kwargs... )
62
+ function Libtask. TapedTask (taped_globals, model:: TracedModel ; kwargs... )
63
+ return Libtask. TapedTask (
64
+ taped_globals, model. evaluator[1 ], model. evaluator[2 : end ]. .. ; kwargs...
65
+ )
77
66
end
78
67
79
68
abstract type ParticleInference <: InferenceAlgorithm end
@@ -403,11 +392,11 @@ end
403
392
404
393
function trace_local_varinfo_maybe (varinfo)
405
394
try
406
- trace = AdvancedPS . current_trace ()
407
- return trace. model. f. varinfo
395
+ trace = Libtask . get_taped_globals (Any) . other
396
+ return ( trace === nothing ? varinfo : trace . model. f. varinfo) :: AbstractVarInfo
408
397
catch e
409
398
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
410
- if e == KeyError (:__trace ) || current_task () . storage isa Nothing
399
+ if e == KeyError (:task_variable )
411
400
return varinfo
412
401
else
413
402
rethrow (e)
@@ -417,11 +406,10 @@ end
417
406
418
407
function trace_local_rng_maybe (rng:: Random.AbstractRNG )
419
408
try
420
- trace = AdvancedPS. current_trace ()
421
- return trace. rng
409
+ return Libtask. get_taped_globals (Any). rng
422
410
catch e
423
411
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
424
- if e == KeyError (:__trace ) || current_task () . storage isa Nothing
412
+ if e == KeyError (:task_variable )
425
413
return rng
426
414
else
427
415
rethrow (e)
@@ -487,6 +475,25 @@ function AdvancedPS.Trace(
487
475
488
476
tmodel = TracedModel (model, sampler, newvarinfo, rng)
489
477
newtrace = AdvancedPS. Trace (tmodel, rng)
490
- AdvancedPS. addreference! (newtrace. model. ctask. task, newtrace)
491
478
return newtrace
492
479
end
480
+
481
+ # We need to tell Libtask which calls may have `produce` calls within them. In practice most
482
+ # of these won't be needed, because of inlining and the fact that `might_produce` is only
483
+ # called on `:invoke` expressions rather than `:call`s, but since those are implementation
484
+ # details of the compiler, we set a bunch of methods as might_produce = true. We start with
485
+ # `acclogp_observe!!` which is what calls `produce` and go up the call stack.
486
+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}} ) = true
487
+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}} ) = true
488
+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}} ) = true
489
+ function Libtask. might_produce (
490
+ :: Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
491
+ )
492
+ return true
493
+ end
494
+ function Libtask. might_produce (
495
+ :: Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
496
+ )
497
+ return true
498
+ end
499
+ Libtask. might_produce (:: Type{<:Tuple{<:DynamicPPL.Model,Vararg}} ) = true
0 commit comments