Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ AdvancedPSLibtaskExt = "Libtask"
[compat]
AbstractMCMC = "2, 3, 4, 5"
Distributions = "0.23, 0.24, 0.25"
Libtask = "0.9"
Libtask = "0.9.2"
Random = "<0.0.1, 1"
Random123 = "1.3"
Requires = "1.0"
Expand Down
56 changes: 37 additions & 19 deletions ext/AdvancedPSLibtaskExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ else
end

# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
# and sometimes both an RNG and other information. In Turing.jl this other information
# and sometimes both an RNG and other information. In Turing.jl the other information
# is a VarInfo. This struct puts those in a single struct. Note the abstract type of
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
struct TapedGlobals{RngType}
Expand Down Expand Up @@ -49,6 +49,29 @@ end

const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}

"""Get the RNG from a `LibtaskTrace`."""
function get_rng(trace::LibtaskTrace)
return trace.model.ctask.taped_globals.rng
end

"""Set the RNG for a `LibtaskTrace`."""
function set_rng!(trace::LibtaskTrace, rng::Random.AbstractRNG)
taped_globals = trace.model.ctask.taped_globals
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, taped_globals.other))
trace.rng = rng
return trace
end

"""Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
function set_other_global!(trace::LibtaskTrace, other)
rng = get_rng(trace)
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, other))
return trace
end

"""Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
get_other_global(trace::LibtaskTrace) = trace.model.ctask.taped_globals.other

function AdvancedPS.Trace(
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
)
Expand All @@ -58,30 +81,25 @@ end
# step to the next observe statement and
# return the log probability of the transition (or nothing if done)
function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false)
taped_globals = trace.model.ctask.taped_globals
rng = taped_globals.rng
rng = get_rng(trace)
isref ? AdvancedPS.load_state!(rng) : AdvancedPS.save_state!(rng)
AdvancedPS.inc_counter!(rng)

Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, taped_globals.other))
trace.rng = rng

set_rng!(trace, rng)
# Move to next step
return Libtask.consume(trace.model.ctask)
end

# create a backward reference in task_local_storage
function AdvancedPS.addreference!(task::Libtask.TapedTask, trace::LibtaskTrace)
rng = task.taped_globals.rng
Libtask.set_taped_globals!(task, TapedGlobals(rng, trace))
return task
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that addreference! is no longer needed, given that it stores a self-reference? If so, can we remove it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last bug that I had here was in fact a missing call to addreference!, I think in fork. We also need to call it in Turing's particle_mcmc, to keep the reference in sync so that we can access the varinfo of the trace from within the TapedTask. There's probably a way to get rid of it, but that would require some refactoring, which would require me learning better what is going on. If there's a plan to merge AdvancedPS into Turing proper, do you think that's worth the effort now?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there's a plan to merge AdvancedPS into Turing proper, do you think that's worth the effort now?

I suggest removing it from this PR since addreference! adds another layer of indirection. Given that this PR is already a breaking release and requires updates to Turing, I think it is worth it.

More context:

I think we are safe if we always store and retrieve (rng, varinfo) via set_taped_globals! / get_taped_globals. The entire motivation for addreference! is to be able to copy varinfo external to a live particle (e.g. during the resampling step, multiple child particles share the same varinfo). It should produce correct results if particles always retrieve varinfo via get_taped_global. Otherwise, something weird is happening.

Set a backreference so that the TapedTask in `trace` stores the `trace` itself in the
taped globals.
"""
function AdvancedPS.addreference!(trace::LibtaskTrace)
set_other_global!(trace, trace)
return trace
end

function AdvancedPS.update_rng!(trace::LibtaskTrace)
taped_globals = trace.model.ctask.taped_globals
new_rng = deepcopy(taped_globals.rng)
trace.rng = new_rng
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(new_rng, taped_globals.other))
set_rng!(trace, deepcopy(get_rng(trace)))
return trace
end

Expand All @@ -91,19 +109,19 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
AdvancedPS.update_rng!(newtrace)
isref && AdvancedPS.delete_retained!(newtrace.model.f)
isref && delete_seeds!(newtrace)
AdvancedPS.addreference!(newtrace)
return newtrace
end

# PG requires keeping all randomness for the reference particle
# Create new task and copy randomness
function AdvancedPS.forkr(trace::LibtaskTrace)
taped_globals = trace.model.ctask.taped_globals
rng = taped_globals.rng
rng = get_rng(trace)
newf = AdvancedPS.reset_model(trace.model.f)
Random123.set_counter!(rng, 1)
trace.rng = rng

ctask = Libtask.TapedTask(TapedGlobals(rng, taped_globals.other), newf)
ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)

# add backward reference
Expand Down
2 changes: 1 addition & 1 deletion test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
end

trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())
AdvancedPS.addreference!(trace.model.ctask, trace)
AdvancedPS.addreference!(trace)

@test AdvancedPS.advance!(trace, false) === objectid(trace)
end
Expand Down
Loading