Skip to content

Commit 16b047f

Browse files
mhauruyebaipenelopeysm
authored
AdvancedPS v0.7 (and thus Libtask v0.9) support (#2585)
* AdvancedPS v0.7 support, work in progress * Fixing particle_mcmc.jl * Remove use of AdvancedPS.addreference! * Improve a comment * Update Project.toml (#2598) * Fix a bug and a test * Bump Libtask to 0.9.3 Co-authored-by: Hong Ge <[email protected]> * Fix seed setting, increase iterations * Increate a test iteration count --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Penelope Yong <[email protected]>
1 parent a5408fb commit 16b047f

File tree

7 files changed

+47
-39
lines changed

7 files changed

+47
-39
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ AbstractPPL = "0.11, 0.12, 0.13"
5454
Accessors = "0.1"
5555
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
5656
AdvancedMH = "0.8"
57-
AdvancedPS = "0.6.0"
57+
AdvancedPS = "0.7"
5858
AdvancedVI = "0.4"
5959
BangBang = "0.4.2"
6060
Bijectors = "0.14, 0.15"
@@ -67,7 +67,7 @@ DynamicHMC = "3.4"
6767
DynamicPPL = "0.36.3"
6868
EllipticalSliceSampling = "0.5, 1, 2"
6969
ForwardDiff = "0.10.3, 1"
70-
Libtask = "0.8.8"
70+
Libtask = "0.9.3"
7171
LinearAlgebra = "1"
7272
LogDensityProblems = "2"
7373
MCMCChains = "5, 6, 7"
@@ -85,7 +85,7 @@ Statistics = "1.6"
8585
StatsAPI = "1.6"
8686
StatsBase = "0.32, 0.33, 0.34"
8787
StatsFuns = "0.8, 0.9, 1"
88-
julia = "1.10.2"
88+
julia = "1.10.8"
8989

9090
[extras]
9191
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"

src/mcmc/particle_mcmc.jl

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@ function TracedModel(
2525
"Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.",
2626
)
2727
end
28-
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
29-
model, sampler, varinfo, (model.f, args...)
30-
)
28+
evaluator = (model.f, args...)
29+
return TracedModel(model, sampler, varinfo, evaluator)
3130
end
3231

3332
function AdvancedPS.advance!(
@@ -59,20 +58,10 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
5958
return trace
6059
end
6160

62-
function AdvancedPS.update_rng!(
63-
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
64-
)
65-
# Extract the `args`.
66-
args = trace.model.ctask.args
67-
# From `args`, extract the `SamplingContext`, which contains the RNG.
68-
sampling_context = args[3]
69-
rng = sampling_context.rng
70-
trace.rng = rng
71-
return trace
72-
end
73-
74-
function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ?
75-
return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
61+
function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
62+
return Libtask.TapedTask(
63+
taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs...
64+
)
7665
end
7766

7867
abstract type ParticleInference <: InferenceAlgorithm end
@@ -402,11 +391,11 @@ end
402391

403392
function trace_local_varinfo_maybe(varinfo)
404393
try
405-
trace = AdvancedPS.current_trace()
406-
return trace.model.f.varinfo
394+
trace = Libtask.get_taped_globals(Any).other
395+
return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo
407396
catch e
408397
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
409-
if e == KeyError(:__trace) || current_task().storage isa Nothing
398+
if e == KeyError(:task_variable)
410399
return varinfo
411400
else
412401
rethrow(e)
@@ -416,11 +405,10 @@ end
416405

417406
function trace_local_rng_maybe(rng::Random.AbstractRNG)
418407
try
419-
trace = AdvancedPS.current_trace()
420-
return trace.rng
408+
return Libtask.get_taped_globals(Any).rng
421409
catch e
422410
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
423-
if e == KeyError(:__trace) || current_task().storage isa Nothing
411+
if e == KeyError(:task_variable)
424412
return rng
425413
else
426414
rethrow(e)
@@ -481,6 +469,25 @@ function AdvancedPS.Trace(
481469

482470
tmodel = TracedModel(model, sampler, newvarinfo, rng)
483471
newtrace = AdvancedPS.Trace(tmodel, rng)
484-
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
485472
return newtrace
486473
end
474+
475+
# We need to tell Libtask which calls may have `produce` calls within them. In practice most
476+
# of these won't be needed, because of inlining and the fact that `might_produce` is only
477+
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
478+
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
479+
# `acclogp_observe!!` which is what calls `produce` and go up the call stack.
480+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true
481+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
482+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
483+
function Libtask.might_produce(
484+
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
485+
)
486+
return true
487+
end
488+
function Libtask.might_produce(
489+
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
490+
)
491+
return true
492+
end
493+
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ ADTypes = "1"
4343
AbstractMCMC = "5"
4444
AbstractPPL = "0.11, 0.12, 0.13"
4545
AdvancedMH = "0.6, 0.7, 0.8"
46-
AdvancedPS = "=0.6.0"
46+
AdvancedPS = "0.7"
4747
AdvancedVI = "0.4"
4848
Aqua = "0.8"
4949
BangBang = "0.4"

test/essential/container.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ using Turing
2323
model = test()
2424
trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG())
2525

26-
# Make sure we link the traces
27-
@test haskey(trace.model.ctask.task.storage, :__trace)
26+
# Make sure the backreference from taped_globals to the trace is in place.
27+
@test trace.model.ctask.taped_globals.other === trace
2828

2929
res = AdvancedPS.advance!(trace, false)
3030
@test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 1

test/mcmc/Inference.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ using Turing
4444
end
4545

4646
# Should also be stable with an explicit RNG
47-
seed = 5
48-
rng = Random.MersenneTwister(seed)
47+
local_seed = 5
48+
rng = Random.MersenneTwister(local_seed)
4949
for sampler in samplers
50-
Random.seed!(rng, seed)
50+
Random.seed!(rng, local_seed)
5151
chain1 = sample(rng, model, sampler, MCMCThreads(), 10, 4)
5252

53-
Random.seed!(rng, seed)
53+
Random.seed!(rng, local_seed)
5454
chain2 = sample(rng, model, sampler, MCMCThreads(), 10, 4)
5555

5656
@test chain1.value == chain2.value
@@ -256,9 +256,9 @@ using Turing
256256
pg = PG(10)
257257
gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10))
258258

259-
chn_s = sample(StableRNG(seed), testbb(obs), smc, 200)
260-
chn_p = sample(StableRNG(seed), testbb(obs), pg, 200)
261-
chn_g = sample(StableRNG(seed), testbb(obs), gibbs, 200)
259+
chn_s = sample(StableRNG(seed), testbb(obs), smc, 2_000)
260+
chn_p = sample(StableRNG(seed), testbb(obs), pg, 2_000)
261+
chn_g = sample(StableRNG(seed), testbb(obs), gibbs, 2_000)
262262

263263
check_numerical(chn_s, [:p], [meanp]; atol=0.05)
264264
check_numerical(chn_p, [:x], [meanp]; atol=0.1)
@@ -647,7 +647,7 @@ using Turing
647647
@model function e(x=1.0)
648648
return x ~ Normal()
649649
end
650-
# Can't test with HMC/NUTS because some AD backends error; see
650+
# Can't test with HMC/NUTS because some AD backends error; see
651651
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802
652652
@test sample(e(), IS(), 100) isa MCMCChains.Chains
653653
end

test/mcmc/ess.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ using Turing
6464
@varname(mu1) => ESS(),
6565
@varname(mu2) => ESS(),
6666
)
67-
chain = sample(StableRNG(seed), MoGtest_default, alg, 2000)
67+
chain = sample(StableRNG(seed), MoGtest_default, alg, 5000)
6868
check_MoGtest_default(chain; atol=0.1)
6969
end
7070

test/mcmc/particle_mcmc.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ using Turing
3434

3535
tested = sample(normal(), SMC(), 100)
3636

37+
# TODO(mhauru) This needs an explanation for why it fails.
3738
# failing test
3839
@model function fail_smc()
3940
a ~ Normal(4, 5)

0 commit comments

Comments
 (0)