1717end
1818
1919# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
20- # and sometimes both an RNG and other information. In Turing.jl this other information
20+ # and sometimes both an RNG and other information. In Turing.jl the other information
2121# is a VarInfo. This struct puts those in a single struct. Note the abstract type of
2222# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
2323struct TapedGlobals{RngType}
4949
5050const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
5151
52+ """ Get the RNG from a `LibtaskTrace`."""
53+ function get_rng (trace:: LibtaskTrace )
54+ return trace. model. ctask. taped_globals. rng
55+ end
56+
57+ """ Set the RNG for a `LibtaskTrace`."""
58+ function set_rng! (trace:: LibtaskTrace , rng:: Random.AbstractRNG )
59+ taped_globals = trace. model. ctask. taped_globals
60+ Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, taped_globals. other))
61+ trace. rng = rng
62+ return trace
63+ end
64+
65+ """ Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
66+ function set_other_global! (trace:: LibtaskTrace , other)
67+ rng = get_rng (trace)
68+ Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, other))
69+ return trace
70+ end
71+
72+ """ Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
73+ get_other_global (trace:: LibtaskTrace ) = trace. model. ctask. taped_globals. other
74+
5275function AdvancedPS. Trace (
5376 model:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
5477)
5881# step to the next observe statement and
5982# return the log probability of the transition (or nothing if done)
6083function AdvancedPS. advance! (trace:: LibtaskTrace , isref:: Bool = false )
61- taped_globals = trace. model. ctask. taped_globals
62- rng = taped_globals. rng
84+ rng = get_rng (trace)
6385 isref ? AdvancedPS. load_state! (rng) : AdvancedPS. save_state! (rng)
6486 AdvancedPS. inc_counter! (rng)
65-
66- Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, taped_globals. other))
67- trace. rng = rng
68-
87+ set_rng! (trace, rng)
6988 # Move to next step
7089 return Libtask. consume (trace. model. ctask)
7190end
7291
73- # create a backward reference in task_local_storage
74- function AdvancedPS. addreference! (task:: Libtask.TapedTask , trace:: LibtaskTrace )
75- rng = task. taped_globals. rng
76- Libtask. set_taped_globals! (task, TapedGlobals (rng, trace))
77- return task
92+ """
93+ Set a backreference so that the TapedTask in `trace` stores the `trace` itself in the
94+ taped globals.
95+ """
96+ function AdvancedPS. addreference! (trace:: LibtaskTrace )
97+ set_other_global! (trace, trace)
98+ return trace
7899end
79100
80101function AdvancedPS. update_rng! (trace:: LibtaskTrace )
81- taped_globals = trace. model. ctask. taped_globals
82- new_rng = deepcopy (taped_globals. rng)
83- trace. rng = new_rng
84- Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (new_rng, taped_globals. other))
102+ set_rng! (trace, deepcopy (get_rng (trace)))
85103 return trace
86104end
87105
@@ -91,19 +109,19 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
91109 AdvancedPS. update_rng! (newtrace)
92110 isref && AdvancedPS. delete_retained! (newtrace. model. f)
93111 isref && delete_seeds! (newtrace)
112+ AdvancedPS. addreference! (newtrace)
94113 return newtrace
95114end
96115
97116# PG requires keeping all randomness for the reference particle
98117# Create new task and copy randomness
99118function AdvancedPS. forkr (trace:: LibtaskTrace )
100- taped_globals = trace. model. ctask. taped_globals
101- rng = taped_globals. rng
119+ rng = get_rng (trace)
102120 newf = AdvancedPS. reset_model (trace. model. f)
103121 Random123. set_counter! (rng, 1 )
104122 trace. rng = rng
105123
106- ctask = Libtask. TapedTask (TapedGlobals (rng, taped_globals . other ), newf)
124+ ctask = Libtask. TapedTask (TapedGlobals (rng, get_other_global (trace) ), newf)
107125 new_tapedmodel = AdvancedPS. LibtaskModel (newf, ctask)
108126
109127 # add backward reference
0 commit comments