Skip to content

Commit a21f24d

Browse files
committed
Merge branch 'breaking' into mhauru/dppl-0.37
2 parents d2c1c92 + 465642e commit a21f24d

File tree

8 files changed

+53
-39
lines changed

8 files changed

+53
-39
lines changed

HISTORY.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
[...]
44

5+
# 0.39.7
6+
7+
Update compatibility to AdvancedPS 0.7 and Libtask 0.9.
8+
9+
These new libraries provide significant speedups for particle MCMC methods.
10+
511
# 0.39.6
612

713
Bumped compatibility of AbstractPPL to include 0.13.

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ AbstractPPL = "0.11, 0.12, 0.13"
5454
Accessors = "0.1"
5555
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
5656
AdvancedMH = "0.8"
57-
AdvancedPS = "0.6.0"
57+
AdvancedPS = "0.7"
5858
AdvancedVI = "0.4"
5959
BangBang = "0.4.2"
6060
Bijectors = "0.14, 0.15"
@@ -67,7 +67,7 @@ DynamicHMC = "3.4"
6767
DynamicPPL = "0.37"
6868
EllipticalSliceSampling = "0.5, 1, 2"
6969
ForwardDiff = "0.10.3, 1"
70-
Libtask = "0.8.8"
70+
Libtask = "0.9.3"
7171
LinearAlgebra = "1"
7272
LogDensityProblems = "2"
7373
MCMCChains = "5, 6, 7"
@@ -85,7 +85,7 @@ Statistics = "1.6"
8585
StatsAPI = "1.6"
8686
StatsBase = "0.32, 0.33, 0.34"
8787
StatsFuns = "0.8, 0.9, 1"
88-
julia = "1.10.2"
88+
julia = "1.10.8"
8989

9090
[extras]
9191
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"

src/mcmc/particle_mcmc.jl

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@ function TracedModel(
2626
"Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.",
2727
)
2828
end
29-
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
30-
spl_model, sampler, varinfo, (spl_model.f, args...)
31-
)
29+
evaluator = (spl_model.f, args...)
30+
return TracedModel(spl_model, sampler, varinfo, evaluator)
3231
end
3332

3433
function AdvancedPS.advance!(
@@ -60,20 +59,10 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
6059
return Accessors.@set trace.model.varinfo = DynamicPPL.resetlogp!!(trace.model.varinfo)
6160
end
6261

63-
function AdvancedPS.update_rng!(
64-
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
65-
)
66-
# Extract the `args`.
67-
args = trace.model.ctask.args
68-
# From `args`, extract the `SamplingContext`, which contains the RNG.
69-
sampling_context = args[3]
70-
rng = sampling_context.rng
71-
trace.rng = rng
72-
return trace
73-
end
74-
75-
function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ?
76-
return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
62+
function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
63+
return Libtask.TapedTask(
64+
taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs...
65+
)
7766
end
7867

7968
abstract type ParticleInference <: InferenceAlgorithm end
@@ -403,11 +392,11 @@ end
403392

404393
function trace_local_varinfo_maybe(varinfo)
405394
try
406-
trace = AdvancedPS.current_trace()
407-
return trace.model.f.varinfo
395+
trace = Libtask.get_taped_globals(Any).other
396+
return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo
408397
catch e
409398
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
410-
if e == KeyError(:__trace) || current_task().storage isa Nothing
399+
if e == KeyError(:task_variable)
411400
return varinfo
412401
else
413402
rethrow(e)
@@ -417,11 +406,10 @@ end
417406

418407
function trace_local_rng_maybe(rng::Random.AbstractRNG)
419408
try
420-
trace = AdvancedPS.current_trace()
421-
return trace.rng
409+
return Libtask.get_taped_globals(Any).rng
422410
catch e
423411
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
424-
if e == KeyError(:__trace) || current_task().storage isa Nothing
412+
if e == KeyError(:task_variable)
425413
return rng
426414
else
427415
rethrow(e)
@@ -487,6 +475,25 @@ function AdvancedPS.Trace(
487475

488476
tmodel = TracedModel(model, sampler, newvarinfo, rng)
489477
newtrace = AdvancedPS.Trace(tmodel, rng)
490-
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
491478
return newtrace
492479
end
480+
481+
# We need to tell Libtask which calls may have `produce` calls within them. In practice most
482+
# of these won't be needed, because of inlining and the fact that `might_produce` is only
483+
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
484+
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
485+
# `acclogp_observe!!` which is what calls `produce` and go up the call stack.
486+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true
487+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
488+
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
489+
function Libtask.might_produce(
490+
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
491+
)
492+
return true
493+
end
494+
function Libtask.might_produce(
495+
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
496+
)
497+
return true
498+
end
499+
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ ADTypes = "1"
4343
AbstractMCMC = "5"
4444
AbstractPPL = "0.11, 0.12, 0.13"
4545
AdvancedMH = "0.6, 0.7, 0.8"
46-
AdvancedPS = "=0.6.0"
46+
AdvancedPS = "0.7"
4747
AdvancedVI = "0.4"
4848
Aqua = "0.8"
4949
BangBang = "0.4"

test/essential/container.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ using Turing
2323
model = test()
2424
trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG())
2525

26-
# Make sure we link the traces
27-
@test haskey(trace.model.ctask.task.storage, :__trace)
26+
# Make sure the backreference from taped_globals to the trace is in place.
27+
@test trace.model.ctask.taped_globals.other === trace
2828

2929
res = AdvancedPS.advance!(trace, false)
3030
@test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 1

test/mcmc/Inference.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ using Turing
4444
end
4545

4646
# Should also be stable with an explicit RNG
47-
seed = 5
48-
rng = Random.MersenneTwister(seed)
47+
local_seed = 5
48+
rng = Random.MersenneTwister(local_seed)
4949
for sampler in samplers
50-
Random.seed!(rng, seed)
50+
Random.seed!(rng, local_seed)
5151
chain1 = sample(rng, model, sampler, MCMCThreads(), 10, 4)
5252

53-
Random.seed!(rng, seed)
53+
Random.seed!(rng, local_seed)
5454
chain2 = sample(rng, model, sampler, MCMCThreads(), 10, 4)
5555

5656
@test chain1.value == chain2.value
@@ -211,9 +211,9 @@ using Turing
211211
pg = PG(10)
212212
gibbs = Gibbs(:p => HMC(0.2, 3), :x => PG(10))
213213

214-
chn_s = sample(StableRNG(seed), testbb(obs), smc, 200)
215-
chn_p = sample(StableRNG(seed), testbb(obs), pg, 200)
216-
chn_g = sample(StableRNG(seed), testbb(obs), gibbs, 200)
214+
chn_s = sample(StableRNG(seed), testbb(obs), smc, 2_000)
215+
chn_p = sample(StableRNG(seed), testbb(obs), pg, 2_000)
216+
chn_g = sample(StableRNG(seed), testbb(obs), gibbs, 2_000)
217217

218218
check_numerical(chn_s, [:p], [meanp]; atol=0.05)
219219
check_numerical(chn_p, [:x], [meanp]; atol=0.1)
@@ -602,7 +602,7 @@ using Turing
602602
@model function e(x=1.0)
603603
return x ~ Normal()
604604
end
605-
# Can't test with HMC/NUTS because some AD backends error; see
605+
# Can't test with HMC/NUTS because some AD backends error; see
606606
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802
607607
@test sample(e(), IS(), 100) isa MCMCChains.Chains
608608
end

test/mcmc/ess.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ using Turing
6464
@varname(mu1) => ESS(),
6565
@varname(mu2) => ESS(),
6666
)
67-
chain = sample(StableRNG(seed), MoGtest_default, alg, 2000)
67+
chain = sample(StableRNG(seed), MoGtest_default, alg, 5000)
6868
check_MoGtest_default(chain; atol=0.1)
6969
end
7070

test/mcmc/particle_mcmc.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ using Turing
3434

3535
tested = sample(normal(), SMC(), 100)
3636

37+
# TODO(mhauru) This needs an explanation for why it fails.
3738
# failing test
3839
@model function fail_smc()
3940
a ~ Normal(4, 5)

0 commit comments

Comments
 (0)