@@ -24,12 +24,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
2424function AdvancedPS. LibtaskModel (
2525 f:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
2626) # Changed the API, need to take care of the RNG properly
27- return AdvancedPS. LibtaskModel (
28- f,
29- Libtask. TapedTask (
30- f, rng, args... ; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (f)}
31- ),
32- )
27+ return AdvancedPS. LibtaskModel (f, Libtask. TapedTask (rng, f, args... ))
3328end
3429
3530"""
5146
5247# step to the next observe statement and
5348# return the log probability of the transition (or nothing if done)
54- function AdvancedPS. advance! (t:: LibtaskTrace , isref:: Bool = false )
55- isref ? AdvancedPS. load_state! (t. rng) : AdvancedPS. save_state! (t. rng)
56- AdvancedPS. inc_counter! (t. rng)
49+ function AdvancedPS. advance! (trace:: LibtaskTrace , isref:: Bool = false )
50+ # Where is the RNG ?
51+ # isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.model.ctask.dynamic_scope) # Nasty
52+ isref ? AdvancedPS. load_state! (trace. rng) : AdvancedPS. save_state! (trace. rng)
53+ AdvancedPS. inc_counter! (trace. rng)
54+
55+ Libtask. set_dynamic_scope! (trace. model. ctask, trace. rng)
5756
5857 # Move to next step
59- return Libtask. consume (t . model. ctask)
58+ return Libtask. consume (trace . model. ctask)
6059end
6160
6261# create a backward reference in task_local_storage
@@ -70,8 +69,9 @@ function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
7069end
7170
7271function AdvancedPS. update_rng! (trace:: LibtaskTrace )
73- rng, = trace. model. ctask. args
74- trace. rng = rng
72+ new_rng = deepcopy (trace. rng)
73+ trace. rng = new_rng
74+ Libtask. set_dynamic_scope! (trace. model. ctask, trace. rng)
7575 return trace
7676end
7777
@@ -81,27 +81,23 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
8181 AdvancedPS. update_rng! (newtrace)
8282 isref && AdvancedPS. delete_retained! (newtrace. model. f)
8383 isref && delete_seeds! (newtrace)
84-
85- # add backward reference
86- AdvancedPS. addreference! (newtrace. model. ctask. task, newtrace)
8784 return newtrace
8885end
8986
9087# PG requires keeping all randomness for the reference particle
9188# Create new task and copy randomness
9289function AdvancedPS. forkr (trace:: LibtaskTrace )
93- newf = AdvancedPS. reset_model (trace. model. f )
90+ newf = AdvancedPS. reset_model (trace. model. ctask . fargs[ 1 ] )
9491 Random123. set_counter! (trace. rng, 1 )
9592
96- ctask = Libtask. TapedTask (
97- newf, trace. rng; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (trace. model. f)}
98- )
93+ ctask = Libtask. TapedTask (trace. rng, newf)
9994 new_tapedmodel = AdvancedPS. LibtaskModel (newf, ctask)
10095
10196 # add backward reference
10297 newtrace = AdvancedPS. Trace (new_tapedmodel, trace. rng)
103- AdvancedPS. addreference! (ctask. task, newtrace)
10498 AdvancedPS. gen_refseed! (newtrace)
99+
100+ Libtask. set_dynamic_scope! (ctask, trace. rng) # Sync trace and rng
105101 return newtrace
106102end
107103
@@ -117,7 +113,7 @@ function AdvancedPS.observe(dist::Distributions.Distribution, x)
117113end
118114
119115"""
120- AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModel
116+ AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModelext
121117"""
122118
123119function AbstractMCMC. step (
@@ -138,7 +134,6 @@ function AbstractMCMC.step(
138134 else
139135 trng = AdvancedPS. TracedRNG ()
140136 trace = AdvancedPS. Trace (deepcopy (model), trng)
141- AdvancedPS. addreference! (trace. model. ctask. task, trace) # TODO : Do we need it here ?
142137 trace
143138 end
144139 end
@@ -176,7 +171,6 @@ function AbstractMCMC.sample(
176171 traces = map (1 : (sampler. nparticles)) do i
177172 trng = AdvancedPS. TracedRNG ()
178173 trace = AdvancedPS. Trace (deepcopy (model), trng)
179- AdvancedPS. addreference! (trace. model. ctask. task, trace) # Do we need it here ?
180174 trace
181175 end
182176
0 commit comments