Skip to content

Commit b3450d3

Browse files
committed
Make Turing tests compatible with Libtask 0.4
1 parent 180458e commit b3450d3

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

test/Turing/core/container.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ mutable struct Trace{Tspl<:AbstractSampler, Tvi<:AbstractVarInfo, Tmodel<:Model}
22
model::Tmodel
33
spl::Tspl
44
vi::Tvi
5-
task::Task
5+
ctask::CTask
66

77
function Trace{SampleFromPrior}(model::Model, spl::AbstractSampler, vi::AbstractVarInfo)
88
return new{SampleFromPrior,typeof(vi),typeof(model)}(model, SampleFromPrior(), vi)
@@ -15,48 +15,50 @@ end
1515
function Base.copy(trace::Trace)
1616
vi = deepcopy(trace.vi)
1717
res = Trace{typeof(trace.spl)}(trace.model, trace.spl, vi)
18-
res.task = copy(trace.task)
18+
res.ctask = copy(trace.ctask)
1919
return res
2020
end
2121

2222
# NOTE: this function is called by `forkr`
2323
function Trace(f::Function, m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
2424
res = Trace{typeof(spl)}(m, spl, deepcopy(vi));
25-
# CTask(()->f());
26-
res.task = CTask( () -> begin res=f(); produce(Val{:done}); res; end )
27-
if res.task.storage === nothing
28-
res.task.storage = IdDict()
25+
ctask = CTask(() -> (res = f(); produce(Val{:done}); res))
26+
task = ctask.task
27+
if task.storage === nothing
28+
task.storage = IdDict()
2929
end
30-
res.task.storage[:turing_trace] = res # create a backward reference in task_local_storage
30+
task.storage[:turing_trace] = res # create a backward reference in task_local_storage
31+
res.ctask = ctask
3132
return res
3233
end
3334
function Trace(m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
3435
res = Trace{typeof(spl)}(m, spl, deepcopy(vi));
35-
# CTask(()->f());
3636
reset_num_produce!(res.vi)
37-
res.task = CTask( () -> begin vi_new=m(vi, spl); produce(Val{:done}); vi_new; end )
38-
if res.task.storage === nothing
39-
res.task.storage = IdDict()
37+
ctask = CTask(() -> (vi_new = m(vi, spl); produce(Val{:done}); vi_new))
38+
task = ctask.task
39+
if task.storage === nothing
40+
task.storage = IdDict()
4041
end
41-
res.task.storage[:turing_trace] = res # create a backward reference in task_local_storage
42+
task.storage[:turing_trace] = res # create a backward reference in task_local_storage
43+
res.ctask = ctask
4244
return res
4345
end
4446

4547
# step to the next observe statement, return log likelihood
46-
Libtask.consume(t::Trace) = (increment_num_produce!(t.vi); consume(t.task))
48+
Libtask.consume(t::Trace) = (increment_num_produce!(t.vi); consume(t.ctask))
4749

4850
# Task copying version of fork for Trace.
4951
function fork(trace :: Trace, is_ref :: Bool = false)
5052
newtrace = copy(trace)
5153
is_ref && set_retained_vns_del_by_spl!(newtrace.vi, newtrace.spl)
52-
newtrace.task.storage[:turing_trace] = newtrace
54+
newtrace.ctask.task.storage[:turing_trace] = newtrace
5355
return newtrace
5456
end
5557

5658
# PG requires keeping all randomness for the reference particle
5759
# Create new task and copy randomness
58-
function forkr(trace :: Trace)
59-
newtrace = Trace(trace.task.code, trace.model, trace.spl, deepcopy(trace.vi))
60+
function forkr(trace::Trace)
61+
newtrace = Trace(trace.ctask.task.code, trace.model, trace.spl, deepcopy(trace.vi))
6062
newtrace.spl = trace.spl
6163
reset_num_produce!(newtrace.vi)
6264
return newtrace

0 commit comments

Comments
 (0)