Skip to content

Commit 7072ff1

Browse files
authored
Mark all methods with DynamicPPL.Model as produceable (#2780)
This needs a Libtask release and version bump, which I'll handle once JuliaRegistrator does its thing.

This essentially implements the plan described in TuringLang/Libtask.jl#217. A lot of the issues stemming from Libtask not picking up model evaluators either with keyword arguments, or in submodels, can be fixed by simply declaring that **every** method that dispatches on `DynamicPPL.Model` is produceable. The mechanism for this is implemented in TuringLang/Libtask.jl#218, and this PR makes use of that.

**For the end-user, this means that we guarantee correctness where models either have submodels or where models have keyword arguments. The user no longer has to mark models with keyword arguments as `@might_produce`.**

I tested performance, and there is no regression — in fact there is a small speedup (although that is probably benchmark noise):

## Submodel case

This was the issue #2772 where non-inlined submodels were not correctly picked up. #2778 fixed this with a strategy that was similar to that in this PR, but was slightly more limited (this PR handles both submodels and keyword arguments together).

```julia
using Turing, StableRNGs, Test

@model function inner(y, x)
    @noinline
    y ~ Normal(x)
end
@model function nested(y)
    x ~ Normal()
    a ~ to_submodel(inner(y, x))
end
m1 = nested(1.0)
@time sample(StableRNG(468), m1, PG(10), 2000; chain_type=Any, progress=false);
# 2.585299 seconds on #2778
# 2.523017 seconds on this PR
```

## Keyword argument case

This was the long-standing issue where models with keyword arguments were originally not picked up by Libtask, and since v0.42.5, could be, but relied on the user themselves manually declaring `Libtask.@might_produce`.

```julia
@model function withkw(; y=0.0)
    x ~ Normal()
    y ~ Normal(x)
end
m1 = withkw(y=10.0);
# Turing.@might_produce(withkw)
@time sample(StableRNG(468), m1, PG(10), 2000; chain_type=Any, progress=false); # withkw case
# 4.741234 seconds on main (requiring @might_produce)
# 4.441797 seconds on this PR (and not requiring @might_produce)
```
1 parent 9be3ed5 commit 7072ff1

File tree

7 files changed

+81
-47
lines changed

7 files changed

+81
-47
lines changed

HISTORY.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
# 0.42.9
2+
3+
Improve handling of model evaluator functions with Libtask.
4+
5+
This means that when running SMC or PG on a model with keyword arguments, you no longer need to use `@might_produce` (see patch notes of v0.42.5 for more details on this).
6+
7+
It also means that submodels with observations inside will now be reliably handled by the SMC/PG samplers, which was not the case before (the observations were only picked up if the submodel was inlined by the Julia compiler, which could lead to correctness issues).
8+
19
# 0.42.8
210

311
Add support for `TensorBoardLogger.jl` via `AbstractMCMC.mcmc_callback`.

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.42.8"
3+
version = "0.42.9"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -65,7 +65,7 @@ DynamicHMC = "3.4"
6565
DynamicPPL = "0.39.1"
6666
EllipticalSliceSampling = "0.5, 1, 2"
6767
ForwardDiff = "0.10.3, 1"
68-
Libtask = "0.9.5"
68+
Libtask = "0.9.14"
6969
LinearAlgebra = "1"
7070
LogDensityProblems = "2"
7171
MCMCChains = "5, 6, 7"

src/mcmc/particle_mcmc.jl

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ function AbstractMCMC.sample(
123123
)
124124
check_model && _check_model(model, sampler)
125125
error_if_threadsafe_eval(model)
126-
check_model_kwargs(model)
127126
# need to add on the `nparticles` keyword argument for `initialstep` to make use of
128127
return AbstractMCMC.mcmcsample(
129128
rng,
@@ -138,28 +137,6 @@ function AbstractMCMC.sample(
138137
)
139138
end
140139

141-
function check_model_kwargs(model::DynamicPPL.Model)
142-
if !isempty(model.defaults)
143-
# If there are keyword arguments, we need to check that the user has
144-
# accounted for this by overloading `might_produce`.
145-
might_produce = Libtask.might_produce(typeof((Core.kwcall, NamedTuple(), model.f)))
146-
if !might_produce
147-
io = IOBuffer()
148-
ctx = IOContext(io, :color => true)
149-
print(
150-
ctx,
151-
"Models with keyword arguments need special treatment to be used" *
152-
" with particle methods. Please run:\n\n",
153-
)
154-
printstyled(
155-
ctx, " Turing.@might_produce($(model.f))"; bold=true, color=:blue
156-
)
157-
print(ctx, "\n\nbefore sampling from this model with particle methods.\n")
158-
error(String(take!(io)))
159-
end
160-
end
161-
end
162-
163140
function Turing.Inference.initialstep(
164141
rng::AbstractRNG,
165142
model::DynamicPPL.Model,
@@ -169,7 +146,6 @@ function Turing.Inference.initialstep(
169146
discard_sample=false,
170147
kwargs...,
171148
)
172-
check_model_kwargs(model)
173149
# Reset the VarInfo.
174150
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
175151
vi = DynamicPPL.empty!!(vi)
@@ -292,7 +268,6 @@ function Turing.Inference.initialstep(
292268
kwargs...,
293269
)
294270
error_if_threadsafe_eval(model)
295-
check_model_kwargs(model)
296271
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
297272

298273
# Create a new set of particles
@@ -534,6 +509,9 @@ Libtask.@might_produce(DynamicPPL.tilde_observe!!)
534509
# Could tilde_assume!! have tighter type bounds on the arguments, namely a GibbsContext?
535510
# That's the only thing that makes tilde_assume calls result in tilde_observe calls.
536511
Libtask.@might_produce(DynamicPPL.tilde_assume!!)
537-
Libtask.@might_produce(DynamicPPL.evaluate!!)
538-
Libtask.@might_produce(DynamicPPL.init!!)
539-
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true
512+
513+
# This handles all models and submodel evaluator functions (including those with keyword
514+
# arguments). The key to this is realising that all model evaluator functions take
515+
# DynamicPPL.Model as an argument, so we can just check for that. See
516+
# https://github.com/TuringLang/Libtask.jl/issues/217.
517+
Libtask.might_produce_if_sig_contains(::Type{<:DynamicPPL.Model}) = true

test/Aqua.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,29 @@
11
module AquaTests
22

33
using Aqua: Aqua
4+
using Libtask: Libtask
45
using Turing
56

6-
# We test ambiguities separately because it catches a lot of problems
7-
# in dependencies but we test it for Turing.
8-
Aqua.test_ambiguities([Turing])
7+
# We test ambiguities only for Turing, because testing ambiguities for all
8+
# packages in the environment leads to a lot of ambiguities from dependencies that we cannot
9+
# control.
10+
#
11+
# `Libtask.might_produce` is excluded because the `@might_produce` macro generates a lot of
12+
# ambiguities that will never happen in practice.
13+
#
14+
# Specifically, when you write `@might_produce f` for a function `f` that has methods that
15+
# take keyword arguments, we have to generate a `might_produce` method for
16+
# `Type{<:Tuple{<:Function,...,typeof(f)}}`. There is no way to circumvent this: see
17+
# https://github.com/TuringLang/Libtask.jl/issues/197. This in turn will cause method
18+
# ambiguities with any other function, say `g`, for which
19+
# `::Type{<:Tuple{typeof(g),Vararg}}` is marked as produceable.
20+
#
21+
# To avoid the method ambiguities, we *could* manually spell out `might_produce` methods for
22+
# each method of `g` manually instead of using Vararg, but that would be both very verbose
23+
# and fragile. It would also not provide any real benefit since those ambiguities are not
24+
# meaningful in practice (in particular, to trigger this we would need to call `g(..., f)`,
25+
# which is incredibly unlikely).
26+
Aqua.test_ambiguities([Turing]; exclude=[Libtask.might_produce])
927
Aqua.test_all(Turing; ambiguities=false)
1028

1129
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
1717
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1818
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1919
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
20+
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
2021
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2122
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2223
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
@@ -57,6 +58,7 @@ DynamicPPL = "0.39.6"
5758
FiniteDifferences = "0.10.8, 0.11, 0.12"
5859
ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1"
5960
HypothesisTests = "0.11"
61+
Libtask = "0.9.14"
6062
LinearAlgebra = "1"
6163
LogDensityProblems = "2"
6264
LogDensityProblemsAD = "1.4"

test/mcmc/Inference.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -314,17 +314,12 @@ using Turing
314314
return priors
315315
end
316316

317-
@test_throws ErrorException chain = sample(
318-
StableRNG(seed), gauss2(; x=x), PG(10), 10
319-
)
320-
@test_throws ErrorException chain = sample(
321-
StableRNG(seed), gauss2(; x=x), SMC(), 10
322-
)
323-
324-
@test_throws ErrorException chain = sample(
317+
chain = sample(StableRNG(seed), gauss2(; x=x), PG(10), 10)
318+
chain = sample(StableRNG(seed), gauss2(; x=x), SMC(), 10)
319+
chain = sample(
325320
StableRNG(seed), gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), PG(10), 10
326321
)
327-
@test_throws ErrorException chain = sample(
322+
chain = sample(
328323
StableRNG(seed), gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10
329324
)
330325

test/mcmc/particle_mcmc.jl

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,26 +161,59 @@ end
161161
@test mean(c[:x]) > 0.7
162162
end
163163

164-
# https://github.com/TuringLang/Turing.jl/issues/2007
165164
@testset "keyword argument handling" begin
166165
@model function kwarg_demo(y; n=0.0)
167166
x ~ Normal(n)
168167
return y ~ Normal(x)
169168
end
170-
@test_throws "Models with keyword arguments" sample(kwarg_demo(5.0), PG(20), 10)
171169

172-
# Check that enabling `might_produce` does allow sampling
173-
@might_produce kwarg_demo
174170
chain = sample(StableRNG(468), kwarg_demo(5.0), PG(20), 1000)
175171
@test chain isa MCMCChains.Chains
176172
@test mean(chain[:x]) ≈ 2.5 atol = 0.2
177173

178-
# Check that the keyword argument's value is respected
179174
chain2 = sample(StableRNG(468), kwarg_demo(5.0; n=10.0), PG(20), 1000)
180175
@test chain2 isa MCMCChains.Chains
181176
@test mean(chain2[:x]) ≈ 7.5 atol = 0.2
182177
end
183178

179+
@testset "submodels without kwargs" begin
180+
@model function inner(y, x)
181+
# Mark as noinline explicitly to make sure that behaviour is not reliant on the
182+
# Julia compiler inlining it.
183+
# See https://github.com/TuringLang/Turing.jl/issues/2772
184+
@noinline
185+
return y ~ Normal(x)
186+
end
187+
@model function nested(y)
188+
x ~ Normal()
189+
return a ~ to_submodel(inner(y, x))
190+
end
191+
m1 = nested(1.0)
192+
chn = sample(StableRNG(468), m1, PG(10), 1000)
193+
@test mean(chn[:x]) ≈ 0.5 atol = 0.1
194+
end
195+
196+
@testset "submodels with kwargs" begin
197+
@model function inner_kwarg(y; n=0.0)
198+
@noinline # See above
199+
x ~ Normal(n)
200+
return y ~ Normal(x)
201+
end
202+
@model function outer_kwarg1()
203+
return a ~ to_submodel(inner_kwarg(5.0))
204+
end
205+
m1 = outer_kwarg1()
206+
chn1 = sample(StableRNG(468), m1, PG(10), 1000)
207+
@test mean(chn1[Symbol("a.x")]) ≈ 2.5 atol = 0.2
208+
209+
@model function outer_kwarg2(n)
210+
return a ~ to_submodel(inner_kwarg(5.0; n=n))
211+
end
212+
m2 = outer_kwarg2(10.0)
213+
chn2 = sample(StableRNG(468), m2, PG(10), 1000)
214+
@test mean(chn2[Symbol("a.x")]) ≈ 7.5 atol = 0.2
215+
end
216+
184217
@testset "refuses to run threadsafe eval" begin
185218
# PG can't run models that have nondeterministic evaluation order,
186219
# so it should refuse to run models marked as threadsafe.

0 commit comments

Comments
 (0)