@@ -2,7 +2,7 @@ mutable struct Trace{Tspl<:AbstractSampler, Tvi<:AbstractVarInfo, Tmodel<:Model}
2
2
model:: Tmodel
3
3
spl:: Tspl
4
4
vi:: Tvi
5
- task :: Task
5
+ ctask :: CTask
6
6
7
7
function Trace {SampleFromPrior} (model:: Model , spl:: AbstractSampler , vi:: AbstractVarInfo )
8
8
return new {SampleFromPrior,typeof(vi),typeof(model)} (model, SampleFromPrior (), vi)
15
15
function Base. copy (trace:: Trace )
16
16
vi = deepcopy (trace. vi)
17
17
res = Trace {typeof(trace.spl)} (trace. model, trace. spl, vi)
18
- res. task = copy (trace. task )
18
+ res. ctask = copy (trace. ctask )
19
19
return res
20
20
end
21
21
22
22
# NOTE: this function is called by `forkr`
23
23
function Trace (f:: Function , m:: Model , spl:: AbstractSampler , vi:: AbstractVarInfo )
24
24
res = Trace {typeof(spl)} (m, spl, deepcopy (vi));
25
- # CTask(()-> f());
26
- res . task = CTask ( () -> begin res = f (); produce (Val{ :done }); res; end )
27
- if res . task. storage === nothing
28
- res . task. storage = IdDict ()
25
+ ctask = CTask (() -> (res = f (); produce (Val{ :done }); res))
26
+ task = ctask . task
27
+ if task. storage === nothing
28
+ task. storage = IdDict ()
29
29
end
30
- res. task. storage[:turing_trace ] = res # create a backward reference in task_local_storage
30
+ task. storage[:turing_trace ] = res # create a backward reference in task_local_storage
31
+ res. ctask = ctask
31
32
return res
32
33
end
33
34
function Trace (m:: Model , spl:: AbstractSampler , vi:: AbstractVarInfo )
34
35
res = Trace {typeof(spl)} (m, spl, deepcopy (vi));
35
- # CTask(()->f());
36
36
reset_num_produce! (res. vi)
37
- res. task = CTask ( () -> begin vi_new= m (vi, spl); produce (Val{:done }); vi_new; end )
38
- if res. task. storage === nothing
39
- res. task. storage = IdDict ()
37
+ ctask = CTask (() -> (vi_new = m (vi, spl); produce (Val{:done }); vi_new))
38
+ task = ctask. task
39
+ if task. storage === nothing
40
+ task. storage = IdDict ()
40
41
end
41
- res. task. storage[:turing_trace ] = res # create a backward reference in task_local_storage
42
+ task. storage[:turing_trace ] = res # create a backward reference in task_local_storage
43
+ res. ctask = ctask
42
44
return res
43
45
end
44
46
45
47
# step to the next observe statement, return log likelihood
46
- Libtask. consume (t:: Trace ) = (increment_num_produce! (t. vi); consume (t. task ))
48
+ Libtask. consume (t:: Trace ) = (increment_num_produce! (t. vi); consume (t. ctask ))
47
49
48
50
# Task copying version of fork for Trace.
49
51
function fork (trace :: Trace , is_ref :: Bool = false )
50
52
newtrace = copy (trace)
51
53
is_ref && set_retained_vns_del_by_spl! (newtrace. vi, newtrace. spl)
52
- newtrace. task. storage[:turing_trace ] = newtrace
54
+ newtrace. ctask . task. storage[:turing_trace ] = newtrace
53
55
return newtrace
54
56
end
55
57
56
58
# PG requires keeping all randomness for the reference particle
57
59
# Create new task and copy randomness
58
- function forkr (trace :: Trace )
59
- newtrace = Trace (trace. task. code, trace. model, trace. spl, deepcopy (trace. vi))
60
+ function forkr (trace:: Trace )
61
+ newtrace = Trace (trace. ctask . task. code, trace. model, trace. spl, deepcopy (trace. vi))
60
62
newtrace. spl = trace. spl
61
63
reset_num_produce! (newtrace. vi)
62
64
return newtrace
0 commit comments