Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions ext/AdvancedPSLibtaskExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ else
using ..Libtask: Libtask
end

# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
# and sometimes both an RNG and other information. In Turing.jl this other information
# is a VarInfo. This struct puts those in a single struct. Note the abstract type of
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
struct TapedGlobals{RngType}
rng::RngType
other::Any
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we remove addreference!, the field other will only store varinfo instances.

end

TapedGlobals(rng::Random.AbstractRNG) = TapedGlobals(rng, nothing)

"""
LibtaskModel{F}

Expand All @@ -24,7 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
function AdvancedPS.LibtaskModel(
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
) # Changed the API, need to take care of the RNG properly
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(rng, f, args...))
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
end

"""
Expand All @@ -47,30 +58,30 @@ end
# step to the next observe statement and
# return the log probability of the transition (or nothing if done)
function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false)
# Where is the RNG ?
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
AdvancedPS.inc_counter!(trace.rng)
taped_globals = trace.model.ctask.taped_globals
rng = taped_globals.rng
isref ? AdvancedPS.load_state!(rng) : AdvancedPS.save_state!(rng)
AdvancedPS.inc_counter!(rng)

Libtask.set_taped_globals!(trace.model.ctask, trace.rng)
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, taped_globals.other))
trace.rng = rng

# Move to next step
return Libtask.consume(trace.model.ctask)
end

# create a backward reference in task_local_storage
function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
if task.storage === nothing
task.storage = IdDict()
end
task.storage[:__trace] = trace

function AdvancedPS.addreference!(task::Libtask.TapedTask, trace::LibtaskTrace)
rng = task.taped_globals.rng
Libtask.set_taped_globals!(task, TapedGlobals(rng, trace))
return task
end

function AdvancedPS.update_rng!(trace::LibtaskTrace)
new_rng = deepcopy(trace.rng)
taped_globals = trace.model.ctask.taped_globals
new_rng = deepcopy(taped_globals.rng)
trace.rng = new_rng
Libtask.set_taped_globals!(trace.model.ctask, trace.rng)
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(new_rng, taped_globals.other))
return trace
end

Expand All @@ -86,17 +97,18 @@ end
# PG requires keeping all randomness for the reference particle
# Create new task and copy randomness
function AdvancedPS.forkr(trace::LibtaskTrace)
taped_globals = trace.model.ctask.taped_globals
rng = taped_globals.rng
newf = AdvancedPS.reset_model(trace.model.f)
Random123.set_counter!(trace.rng, 1)
Random123.set_counter!(rng, 1)
trace.rng = rng

ctask = Libtask.TapedTask(trace.rng, newf)
ctask = Libtask.TapedTask(TapedGlobals(rng, taped_globals.other), newf)
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)

# add backward reference
newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng)
newtrace = AdvancedPS.Trace(new_tapedmodel, rng)
AdvancedPS.gen_refseed!(newtrace)

Libtask.set_taped_globals!(ctask, trace.rng) # Sync trace and rng
return newtrace
end

Expand Down
2 changes: 0 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ function observe end
function replay end
function addreference! end

current_trace() = current_task().storage[:__trace]

# We need this one to be visible outside of the extension for dispatching (Turing.jl).
struct LibtaskModel{F,T}
f::F
Expand Down
7 changes: 4 additions & 3 deletions test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,18 @@
@test consume(a.model.ctask) == 4
end

@testset "current trace" begin
@testset "Back-reference" begin
struct TaskIdModel <: AdvancedPS.AbstractGenericModel end

function (model::TaskIdModel)()
# Just print the task it's running in
id = objectid(AdvancedPS.current_trace())
trace = Libtask.get_taped_globals(Any).other
id = objectid(trace)
return Libtask.produce(id)
end

trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())
AdvancedPS.addreference!(trace.model.ctask.task, trace)
AdvancedPS.addreference!(trace.model.ctask, trace)

@test AdvancedPS.advance!(trace, false) === objectid(trace)
end
Expand Down
16 changes: 8 additions & 8 deletions test/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@

function (m::NormalModel)()
# First latent variable.
rng = Libtask.get_taped_globals(Any)
rng = Libtask.get_taped_globals(Any).rng
m.a = a = rand(rng, Normal(4, 5))

# First observation.
AdvancedPS.observe(Normal(a, 2), 3)

# Second latent variable.
rng = Libtask.get_taped_globals(Any)
rng = Libtask.get_taped_globals(Any).rng
m.b = b = rand(rng, Normal(a, 1))

# Second observation.
Expand All @@ -55,10 +55,10 @@
end

function (m::FailSMCModel)()
rng = Libtask.get_taped_globals(Any)
rng = Libtask.get_taped_globals(Any).rng
m.a = a = rand(rng, Normal(4, 5))

rng = Libtask.get_taped_globals(Any)
rng = Libtask.get_taped_globals(Any).rng
m.b = b = rand(rng, Normal(a, 1))
if a >= 4
AdvancedPS.observe(Normal(b, 2), 1.5)
Expand All @@ -82,7 +82,7 @@

function (m::TestModel)()
# First hidden variables.
rng = Libtask.get_taped_globals(Any)
rng = Libtask.get_taped_globals(Any).rng
m.a = rand(rng, Normal(0, 1))
m.x = x = rand(rng, Bernoulli(1))
m.b = rand(rng, Gamma(2, 3))
Expand All @@ -91,7 +91,7 @@
AdvancedPS.observe(Bernoulli(x / 2), 1)

# Second hidden variable.
rng = Libtask.get_taped_globals(Any)
rng = Libtask.get_taped_globals(Any).rng
m.c = rand(rng, Beta())

# Second observation.
Expand Down Expand Up @@ -167,11 +167,11 @@
end

function (m::DummyModel)()
rng = Libtask.get_taped_globals(Any)
rng = Libtask.get_taped_globals(Any).rng
m.a = rand(rng, Normal())
AdvancedPS.observe(Normal(), m.a)

rng = Libtask.get_taped_globals(Any)
rng = Libtask.get_taped_globals(Any).rng
m.b = rand(rng, Normal())
return AdvancedPS.observe(Normal(), m.b)
end
Expand Down
Loading