Skip to content

Commit 88e9dad

Browse files
Rik Huijzeryebai
andauthored
Switch to Libtask.TapedTask (#43)
* Switch to Libtask.TapedTask * Rename `ttask` to `task` * Update Project.toml Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent 44f5d1a commit 88e9dad

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.3.0"
4+
version = "0.4"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/container.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
struct Trace{F,R<:TracedRNG}
22
f::F
3-
ctask::Libtask.CTask
3+
task::Libtask.TapedTask
44
rng::R
55
end
66

77
const Particle = Trace
88

99
function Trace(f, rng::TracedRNG)
10-
ctask = Libtask.CTask(f, rng)
10+
task = Libtask.TapedTask(f, rng)
1111

1212
# add backward reference
13-
newtrace = Trace(f, ctask, rng)
14-
addreference!(ctask.task, newtrace)
13+
newtrace = Trace(f, task, rng)
14+
addreference!(task.task, newtrace)
1515

1616
return newtrace
1717
end
1818

19-
function Trace(f, ctask::Libtask.CTask)
20-
return Trace(f, ctask, TracedRNG())
19+
function Trace(f, task::Libtask.TapedTask)
20+
return Trace(f, task, TracedRNG())
2121
end
2222

2323
# Copy task
24-
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask), deepcopy(trace.rng))
24+
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.task), deepcopy(trace.rng))
2525

2626
# step to the next observe statement and
2727
# return the log probability of the transition (or nothing if done)
@@ -30,7 +30,7 @@ function advance!(t::Trace, isref::Bool)
3030
inc_counter!(t.rng)
3131

3232
# Move to next step
33-
return Libtask.consume(t.ctask)
33+
return Libtask.consume(t.task)
3434
end
3535

3636
# reset log probability
@@ -45,7 +45,7 @@ function fork(trace::Trace, isref::Bool=false)
4545
isref && delete_retained!(newtrace.f)
4646

4747
# add backward reference
48-
addreference!(newtrace.ctask.task, newtrace)
48+
addreference!(newtrace.task.task, newtrace)
4949

5050
return newtrace
5151
end
@@ -56,11 +56,11 @@ function forkr(trace::Trace)
5656
newf = reset_model(trace.f)
5757
Random123.set_counter!(trace.rng, 1)
5858

59-
ctask = Libtask.CTask(newf, trace.rng)
59+
task = Libtask.TapedTask(newf, trace.rng)
6060

6161
# add backward reference
62-
newtrace = Trace(newf, ctask, trace.rng)
63-
addreference!(ctask.task, newtrace)
62+
newtrace = Trace(newf, task, trace.rng)
63+
addreference!(task.task, newtrace)
6464

6565
return newtrace
6666
end

test/container.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,15 @@
111111
# Test task copy version of trace
112112
tr = AdvancedPS.Trace(f2, AdvancedPS.TracedRNG())
113113

114-
consume(tr.ctask)
115-
consume(tr.ctask)
114+
consume(tr.task)
115+
consume(tr.task)
116116

117117
a = AdvancedPS.fork(tr)
118-
consume(a.ctask)
119-
consume(a.ctask)
118+
consume(a.task)
119+
consume(a.task)
120120

121-
@test consume(tr.ctask) == 2
122-
@test consume(a.ctask) == 4
121+
@test consume(tr.task) == 2
122+
@test consume(a.task) == 4
123123
end
124124

125125
@testset "seed container" begin

0 commit comments

Comments
 (0)