1919"""
2020 LibtaskModel{F}
2121
22- State wrapper to hold `Libtask.CTask` model initiated from `f`
22+ State wrapper to hold `Libtask.CTask` model initiated from `f`.
2323"""
24- struct LibtaskModel{F1,F2}
25- f:: F1
26- ctask:: Libtask.TapedTask{F2}
27-
28- LibtaskModel (f:: F1 , ctask:: Libtask.TapedTask{F2} ) where {F1,F2} = new {F1,F2} (f, ctask)
29- end
30-
31- function LibtaskModel (f, args... )
32- return LibtaskModel (
24+ function AdvancedPS. LibtaskModel (
25+ f:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
26+ ) # Changed the API, need to take care of the RNG properly
27+ return AdvancedPS. LibtaskModel (
3328 f,
34- Libtask. TapedTask (f, args... ; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (f)}),
29+ Libtask. TapedTask (
30+ f, rng, args... ; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (f)}
31+ ),
3532 )
3633end
3734
38- Base. copy (model:: LibtaskModel ) = LibtaskModel (model. f, copy (model. ctask))
35+ function Base. copy (model:: AdvancedPS.LibtaskModel )
36+ return AdvancedPS. LibtaskModel (model. f, copy (model. ctask))
37+ end
3938
40- const LibtaskTrace{R} = AdvancedPS. Trace{<: LibtaskModel ,R}
39+ const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS. LibtaskModel ,R}
4140
4241function AdvancedPS. Trace (
4342 model:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
4443)
45- return AdvancedPS. Trace (LibtaskModel (model, args... ), rng)
44+ return AdvancedPS. Trace (AdvancedPS . LibtaskModel (model, rng , args... ), rng)
4645end
4746
4847# step to the next observe statement and
@@ -56,7 +55,7 @@ function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false)
5655end
5756
5857# create a backward reference in task_local_storage
59- function addreference! (task:: Task , trace:: LibtaskTrace )
58+ function AdvancedPS . addreference! (task:: Task , trace:: LibtaskTrace )
6059 if task. storage === nothing
6160 task. storage = IdDict ()
6261 end
@@ -65,9 +64,7 @@ function addreference!(task::Task, trace::LibtaskTrace)
6564 return task
6665end
6766
68- current_trace () = current_task (). storage[:__trace ]
69-
70- function update_rng! (trace:: LibtaskTrace )
67+ function AdvancedPS. update_rng! (trace:: LibtaskTrace )
7168 rng, = trace. model. ctask. args
7269 trace. rng = rng
7370 return trace
7673# Task copying version of fork for Trace.
7774function AdvancedPS. fork (trace:: LibtaskTrace , isref:: Bool = false )
7875 newtrace = copy (trace)
79- update_rng! (newtrace)
76+ AdvancedPS . update_rng! (newtrace)
8077 isref && AdvancedPS. delete_retained! (newtrace. model. f)
8178 isref && delete_seeds! (newtrace)
8279
8380 # add backward reference
84- addreference! (newtrace. model. ctask. task, newtrace)
81+ AdvancedPS . addreference! (newtrace. model. ctask. task, newtrace)
8582 return newtrace
8683end
8784
@@ -94,11 +91,11 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
9491 ctask = Libtask. TapedTask (
9592 newf, trace. rng; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (trace. model. f)}
9693 )
97- new_tapedmodel = LibtaskModel (newf, ctask)
94+ new_tapedmodel = AdvancedPS . LibtaskModel (newf, ctask)
9895
9996 # add backward reference
10097 newtrace = AdvancedPS. Trace (new_tapedmodel, trace. rng)
101- addreference! (ctask. task, newtrace)
98+ AdvancedPS . addreference! (ctask. task, newtrace)
10299 AdvancedPS. gen_refseed! (newtrace)
103100 return newtrace
104101end
@@ -135,9 +132,8 @@ function AbstractMCMC.step(
135132 AdvancedPS. forkr (copy (state. trajectory))
136133 else
137134 trng = AdvancedPS. TracedRNG ()
138- gen_model = LibtaskModel (deepcopy (model), trng)
139- trace = AdvancedPS. Trace (LibtaskModel (deepcopy (model), trng), trng)
140- addreference! (gen_model. ctask. task, trace) # Do we need it here ?
135+ trace = AdvancedPS. Trace (deepcopy (model), trng)
136+ AdvancedPS. addreference! (trace. model. ctask. task, trace) # TODO : Do we need it here ?
141137 trace
142138 end
143139 end
@@ -174,9 +170,8 @@ function AbstractMCMC.sample(
174170
175171 traces = map (1 : (sampler. nparticles)) do i
176172 trng = AdvancedPS. TracedRNG ()
177- gen_model = LibtaskModel (deepcopy (model), trng)
178- trace = AdvancedPS. Trace (LibtaskModel (deepcopy (model), trng), trng)
179- addreference! (gen_model. ctask. task, trace) # Do we need it here ?
173+ trace = AdvancedPS. Trace (deepcopy (model), trng)
174+ AdvancedPS. addreference! (trace. model. ctask. task, trace) # Do we need it here ?
180175 trace
181176 end
182177
@@ -202,7 +197,9 @@ function AdvancedPS.replay(particle::AdvancedPS.Particle)
202197 trng = deepcopy (particle. rng)
203198 Random123. set_counter! (trng. rng, 0 )
204199 trng. count = 1
205- trace = AdvancedPS. Trace (LibtaskModel (deepcopy (particle. model. f), trng), trng)
200+ trace = AdvancedPS. Trace (
201+ AdvancedPS. LibtaskModel (deepcopy (particle. model. f), trng), trng
202+ )
206203 score = AdvancedPS. advance! (trace, true )
207204 while ! isnothing (score)
208205 score = AdvancedPS. advance! (trace, true )
0 commit comments