1616 using .. Libtask: Libtask
1717end
1818
19+ # In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
20+ # and sometimes both an RNG and other information. In Turing.jl the other information
21+ # is a VarInfo. This struct puts those in a single struct. Note the abstract type of
22+ # the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
23+ struct TapedGlobals{RngType}
24+ rng:: RngType
25+ other:: Any
26+ end
27+
28+ TapedGlobals(rng:: Random.AbstractRNG ) = TapedGlobals(rng, nothing )
29+
1930"""
2031 LibtaskModel{F}
2132
@@ -24,12 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
2435function AdvancedPS. LibtaskModel(
2536 f:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
2637) # Changed the API, need to take care of the RNG properly
27- return AdvancedPS. LibtaskModel(
28- f,
29- Libtask. TapedTask(
30- f, rng, args... ; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof(f)}
31- ),
32- )
38+ return AdvancedPS. LibtaskModel(f, Libtask. TapedTask(TapedGlobals(rng), f, args... ))
3339end
3440
3541"""
4349
4450const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
4551
46- function AdvancedPS . Trace(
47- model :: AdvancedPS.AbstractGenericModel , rng :: Random.AbstractRNG , args ...
48- )
49- return AdvancedPS . Trace(AdvancedPS . LibtaskModel(model, rng, args ... ), rng)
52+ function Base . copy(trace :: LibtaskTrace )
53+ newtrace = AdvancedPS. Trace(copy(trace . model), deepcopy(trace . rng))
54+ set_other_global!(newtrace, newtrace )
55+ return newtrace
5056end
5157
52- # step to the next observe statement and
53- # return the log probability of the transition (or nothing if done)
54- function AdvancedPS. advance!(t:: LibtaskTrace , isref:: Bool = false )
55- isref ? AdvancedPS. load_state!(t. rng) : AdvancedPS. save_state!(t. rng)
56- AdvancedPS. inc_counter!(t. rng)
57-
58- # Move to next step
59- return Libtask. consume(t. model. ctask)
58+ """ Get the RNG from a `LibtaskTrace`."""
59+ function get_rng(trace:: LibtaskTrace )
60+ return trace. model. ctask. taped_globals. rng
6061end
6162
62- # create a backward reference in task_local_storage
63- function AdvancedPS. addreference!(task:: Task , trace:: LibtaskTrace )
64- if task. storage === nothing
65- task. storage = IdDict()
66- end
67- task. storage[:__trace] = trace
63+ """ Set the RNG for a `LibtaskTrace`."""
64+ function set_rng!(trace:: LibtaskTrace , rng:: Random.AbstractRNG )
65+ other = get_other_global(trace)
66+ Libtask. set_taped_globals!(trace. model. ctask, TapedGlobals(rng, other))
67+ trace. rng = rng
68+ return trace
69+ end
6870
69- return task
71+ """ Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
72+ function set_other_global!(trace:: LibtaskTrace , other)
73+ rng = get_rng(trace)
74+ Libtask. set_taped_globals!(trace. model. ctask, TapedGlobals(rng, other))
75+ return trace
7076end
7177
72- function AdvancedPS. update_rng!(trace:: LibtaskTrace )
73- rng, = trace. model. ctask. args
74- trace. rng = rng
78+ """ Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
79+ get_other_global(trace:: LibtaskTrace ) = trace. model. ctask. taped_globals. other
80+
81+ function AdvancedPS. Trace(
82+ model:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
83+ )
84+ trace = AdvancedPS. Trace(AdvancedPS. LibtaskModel(model, rng, args... ), rng)
85+ # Set a backreference so that the TapedTask in `trace` stores the `trace` itself in its
86+ # taped globals.
87+ set_other_global!(trace, trace)
7588 return trace
7689end
7790
91+ # step to the next observe statement and
92+ # return the log probability of the transition (or nothing if done)
93+ function AdvancedPS. advance!(trace:: LibtaskTrace , isref:: Bool = false )
94+ rng = get_rng(trace)
95+ isref ? AdvancedPS. load_state!(rng) : AdvancedPS. save_state!(rng)
96+ AdvancedPS. inc_counter!(rng)
97+ # Move to next step
98+ return Libtask. consume(trace. model. ctask)
99+ end
100+
78101# Task copying version of fork for Trace.
79102function AdvancedPS. fork(trace:: LibtaskTrace , isref:: Bool = false )
80103 newtrace = copy(trace)
81- AdvancedPS . update_rng !(newtrace)
104+ set_rng !(newtrace, deepcopy(get_rng(newtrace)) )
82105 isref && AdvancedPS. delete_retained!(newtrace. model. f)
83106 isref && delete_seeds!(newtrace)
84-
85- # add backward reference
86- AdvancedPS. addreference!(newtrace. model. ctask. task, newtrace)
87107 return newtrace
88108end
89109
90110# PG requires keeping all randomness for the reference particle
91111# Create new task and copy randomness
92112function AdvancedPS. forkr(trace:: LibtaskTrace )
113+ rng = get_rng(trace)
93114 newf = AdvancedPS. reset_model(trace. model. f)
94- Random123. set_counter!(trace . rng, 1 )
115+ Random123. set_counter!(rng, 1 )
95116
96- ctask = Libtask. TapedTask(
97- newf, trace. rng; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof(trace. model. f)}
98- )
117+ ctask = Libtask. TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
99118 new_tapedmodel = AdvancedPS. LibtaskModel(newf, ctask)
100119
101120 # add backward reference
102- newtrace = AdvancedPS. Trace(new_tapedmodel, trace. rng)
103- AdvancedPS. addreference!(ctask. task, newtrace)
121+ newtrace = AdvancedPS. Trace(new_tapedmodel, rng)
104122 AdvancedPS. gen_refseed!(newtrace)
105123 return newtrace
106124end
@@ -113,7 +131,8 @@ AdvancedPS.update_ref!(::LibtaskTrace) = nothing
113131Observe sample `x` from distribution `dist` and yield its log-likelihood value.
114132"""
115133function AdvancedPS. observe(dist:: Distributions.Distribution , x)
116- return Libtask. produce(Distributions. loglikelihood(dist, x))
134+ Libtask. produce(Distributions. loglikelihood(dist, x))
135+ return nothing
117136end
118137
119138"""
@@ -138,7 +157,6 @@ function AbstractMCMC.step(
138157 else
139158 trng = AdvancedPS. TracedRNG()
140159 trace = AdvancedPS. Trace(deepcopy(model), trng)
141- AdvancedPS. addreference!(trace. model. ctask. task, trace) # TODO : Do we need it here ?
142160 trace
143161 end
144162 end
@@ -153,8 +171,7 @@ function AbstractMCMC.step(
153171 newtrajectory = rand(rng, particles)
154172
155173 replayed = AdvancedPS. replay(newtrajectory)
156- return AdvancedPS. PGSample(replayed. model. f, logevidence),
157- AdvancedPS. PGState(newtrajectory)
174+ return AdvancedPS. PGSample(replayed. model. f, logevidence), AdvancedPS. PGState(replayed)
158175end
159176
160177function AbstractMCMC. sample(
@@ -176,7 +193,6 @@ function AbstractMCMC.sample(
176193 traces = map(1 : (sampler. nparticles)) do i
177194 trng = AdvancedPS. TracedRNG()
178195 trace = AdvancedPS. Trace(deepcopy(model), trng)
179- AdvancedPS. addreference!(trace. model. ctask. task, trace) # Do we need it here ?
180196 trace
181197 end
182198
0 commit comments