Skip to content

Commit 0f7548d

Browse files
committed
Support for submodels (#233)
Part of the motivations for #221 and #222 was so we could add submodels/model-nesting. Well, now we can. Special thanks to @devmotion who reviewed those PRs (several times), improving them significantly, made additional PRs and suggested the current impl of the `@submodel` ❤️ EDIT: We fixed the performance:) This now has zero runtime overhead! See comment-section. EDIT 2: Thanks to @devmotion, we can now alos deal with dynamically specified prefices! - [Motivating example: AR1-prior](#org46a90a5) - [Demos](#org7e05701) - [Can it ever fail?](#org75acb71) - [Benchmarks](#orga99bcf4) <a id="org46a90a5"></a> # Motivating example: AR1-prior ```julia using Turing using DynamicPPL ``` ```julia # Could have made model which samples `num_obs` AR1 samples simulatenously, # but for the sake of showing off dynamic prefixes, we'll only use a vector-implementation. # The matrix implementation will be quite a bit faster too, but oh well. @model function AR1(num_steps, α, μ, σ, ::Type{TV} = Vector{Float64}) where {TV} η ~ MvNormal(num_steps, 1.0) δ = sqrt(1 - α^2) x = TV(undef, num_steps) x[1] = η[1] @inbounds for t = 2:num_steps x[t] = @. α * x[t - 1] + δ * η[t] end return @. μ + σ * x end # Generate an observation σ_obs = 0.1 num_obs = 5 num_steps = 10 ar1 = AR1(num_steps, 0.5, 1.0, 1.0) ys = mapreduce(hcat, 1:num_obs) do i ar1() + σ_obs * randn(num_steps) end ``` 10×5 Matrix{Float64}: 2.30189 0.301618 1.73268 -0.65096 1.46835 2.11187 -1.34878 2.3728 1.02125 3.28422 -0.249064 0.769488 1.34044 3.22175 2.52196 -0.25863 -0.216914 0.528954 3.04756 3.8234 0.372122 0.473511 0.708068 0.76197 0.202003 0.41487 0.759435 1.80162 0.790204 0.12331 1.32585 0.567929 2.74316 1.0874 2.82701 1.84307 1.16138 1.36382 0.735388 1.07423 3.20139 0.75177 1.57236 0.865401 -0.315341 1.22479 1.35688 2.8239 0.597959 0.587955 ```julia @model function demo(y) α ~ Uniform() μ ~ Normal() σ ~ truncated(Normal(), 0, Inf) num_steps = size(y, 1) num_obs = size(y, 2) @inbounds for i = 1:num_obs x = @SubModel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ) y[:, i] ~ MvNormal(x, 0.1) end end; m = demo(y); vi = VarInfo(m); ``` ```julia keys(vi) ``` 8-element Vector{VarName{sym, Tuple{}} where sym}: α μ σ ar1_1.η ar1_2.η ar1_3.η ar1_4.η ar1_5.η ```julia vi[@varname α] ``` 0.9383208224122919 ```julia chain = sample(m, NUTS(1_000, 0.8), 3_000); ``` ┌ Info: Found initial step size │ ϵ = 0.025 └ @ Turing.Inference /home/tor/.julia/packages/Turing/rHLGJ/src/inference/hmc.jl:188 Sampling: 100%|█████████████████████████████████████████| Time: 0:04:00 ```julia chain[1001:end, [:α, :μ, :σ], :] ``` Chains MCMC chain (2000×3×1 Array{Float64, 3}): Iterations = 1001:3000 Thinning interval = 1 Chains = 1 Samples per chain = 2000 parameters = α, μ, σ internals = Summary Statistics parameters mean std naive_se mcse ess rhat Symbol Float64 Float64 Float64 Float64 Float64 Float64 α 0.5474 0.1334 0.0030 0.0073 159.6969 0.9995 μ 1.0039 0.2733 0.0061 0.0168 169.9106 1.0134 σ 1.1294 0.1807 0.0040 0.0106 166.8670 0.9998 Quantiles parameters 2.5% 25.0% 50.0% 75.0% 97.5% Symbol Float64 Float64 Float64 Float64 Float64 α 0.2684 0.4625 0.5534 0.6445 0.7861 μ 0.4248 0.8227 1.0241 1.2011 1.4801 σ 0.8781 1.0018 1.0989 1.2239 1.5472 Yay! We recovered the true parameters :tada: ```julia @benchmark $m($vi) ``` BenchmarkTools.Trial: memory estimate: 12.05 KiB allocs estimate: 123 -------------- minimum time: 15.091 μs (0.00% GC) median time: 17.861 μs (0.00% GC) mean time: 19.582 μs (5.23% GC) maximum time: 10.293 ms (99.46% GC) -------------- samples: 10000 evals/sample: 1 <a id="org7e05701"></a> # Demos ```julia using DynamicPPL, Distributions ``` ┌ Info: Precompiling DynamicPPL [366bfd00-2699-11ea-058f-f148b4cae6d8] └ @ Base loading.jl:1317 ```julia @model function demo1(x) x ~ Normal() end; @model function demo2(x, y) @SubModel demo1(x) y ~ Uniform() end false; m2 = demo2(missing, missing); vi2 = VarInfo(m2); keys(vi2) ``` 2-element Vector{VarName{sym, Tuple{}} where sym}: x y ```julia println(vi2[VarName(Symbol("x"))]) println(vi2[VarName(Symbol("y"))]) ``` 0.3069117531180063 0.7325324947386318 We can also `observe` without issues: ```julia @model function demo2(x, y) @SubModel demo1(x) y ~ Normal(x) end false; m2 = demo2(1000.0, missing); vi2 = VarInfo(m2); keys(vi2) ``` 1-element Vector{VarName{:y, Tuple{}}}: y ```julia vi2[@varname y] ``` 1000.3905079427211 ```julia DynamicPPL.getlogp(vi2) ``` -500001.9141252931 But what if the models have the same variable-names?! "Sure, this is cool and all, but can we even use the values from the nested values in the parent model?" ```julia @model function demo_return(x) x ~ Normal() return x end; @model function demo_useval(x, y) x1 = @SubModel sub1 demo_return(x) x2 = @SubModel sub2 demo_return(y) z ~ Normal(x1 + x2 + 100, 1.0) end false; vi = VarInfo(demo_useval(missing, missing)); keys(vi) ``` 3-element Vector{VarName{sym, Tuple{}} where sym}: sub1.x sub2.x z ```julia vi[@varname z] ``` 101.09066854862154 And just to prove a point: ```julia @model function nested(x, y) @SubModel 1 nested1(x, y) y ~ Uniform() end false; @model function nested1(x, y) @SubModel 2 nested2(x, y) y ~ Uniform() end false; @model function nested2(x, y) z = @SubModel 3 nested3(x, y) y ~ Normal(z, 1.0) end false; @model function nested3(x, y) x ~ Uniform() y ~ Normal(-100.0, 1.0) return x + 1000 end false; m = nested(missing, missing); vi = VarInfo(m); keys(vi) ``` 5-element Vector{VarName{sym, Tuple{}} where sym}: 1.2.3.x 1.2.3.y 1.2.y 1.y y ```julia vi[VarName(Symbol("1.2.y"))] ``` 1000.5609156083766 ```julia DynamicPPL.getlogp(vi) ``` -4.620040828101227 <a id="org75acb71"></a> # Can it ever fail? Yeah, if the user doesn't provide the prefix, it can: ```julia @model function nested(x, y) @SubModel nested1(x, y) y ~ Uniform() end false; @model function nested1(x, y) @SubModel nested2(x, y) y ~ Uniform() end false; @model function nested2(x, y) z = @SubModel nested3(x, y) y ~ Normal(z, 1.0) end false; @model function nested3(x, y) x ~ Uniform() y ~ Normal(-100.0, 1.0) return x + 1000 end false; m = nested(missing, missing); vi = VarInfo(m); keys(vi) ``` 2-element Vector{VarName{sym, Tuple{}} where sym}: x y ```julia # Inner-most value is recorded (i.e. the first one reached) vi[@varname y] ``` -100.16836599596732 And it messes up the logp computation: ```julia DynamicPPL.getlogp(vi) ``` -Inf But I could imagine there's a way for us to fix this, or at least warn the user when this happens. <a id="orga99bcf4"></a> # Benchmarks At this point you're probably wondering, "but does it have any overhead (at runtime)?". For a "shallow" nestings, nah, but if you go deep enough there seems to be a tiny bit (likely because we're calling the "constructor" for the model): ```julia using BenchmarkTools @model function base(x, y) x ~ Uniform() y ~ Uniform() y1 ~ Uniform() z = x + 1000 y12 ~ Normal() y123 ~ Normal(-100.0, 1.0) end m1 = base(missing, missing); vi1 = VarInfo(m1); ``` ```julia @model function nested_shallow(x, y) @SubModel 1 nested1_shallow(x, y) y ~ Uniform() end false; @model function nested1_shallow(x, y) x ~ Uniform() y ~ Uniform() z = x + 1000 y12 ~ Normal() y123 ~ Normal(-100.0, 1.0) end false; m2 = nested_shallow(missing, missing); vi2 = VarInfo(m2); ``` ```julia @model function nested(x, y) @SubModel 1 nested1(x, y) y ~ Uniform() end false; @model function nested1(x, y) @SubModel 2 nested2(x, y) y ~ Uniform() end false; @model function nested2(x, y) z = @SubModel 3 nested3(x, y) y ~ Normal(z, 1.0) end false; @model function nested3(x, y) x ~ Uniform() y ~ Normal(-100.0, 1.0) return x + 1000 end m3 = nested(missing, missing); vi3 = VarInfo(m3); ``` ```julia @model function nested_noprefix(x, y) @SubModel nested_noprefix1(x, y) y ~ Uniform() end false; @model function nested_noprefix1(x, y) @SubModel nested_noprefix2(x, y) y1 ~ Uniform() end false; @model function nested_noprefix2(x, y) z = @SubModel nested_noprefix3(x, y) y2 ~ Normal(z, 1.0) end false; @model function nested_noprefix3(x, y) x ~ Uniform() y3 ~ Normal(-100.0, 1.0) return x + 1000 end m4 = nested_noprefix(missing, missing); vi4 = VarInfo(m4); ``` ```julia keys(vi1) ``` 5-element Vector{VarName{sym, Tuple{}} where sym}: x y y1 y12 y123 ```julia keys(vi2) ``` 5-element Vector{VarName{sym, Tuple{}} where sym}: 1.x 1.y 1.y12 1.y123 y ```julia keys(vi3) ``` 5-element Vector{VarName{sym, Tuple{}} where sym}: 1.2.3.x 1.2.3.y 1.2.y 1.y y ```julia keys(vi4) ``` 5-element Vector{VarName{sym, Tuple{}} where sym}: x y3 y2 y1 y ```julia @benchmark $m1($vi1) ``` BenchmarkTools.Trial: memory estimate: 160 bytes allocs estimate: 5 -------------- minimum time: 1.714 μs (0.00% GC) median time: 1.747 μs (0.00% GC) mean time: 1.835 μs (0.00% GC) maximum time: 6.894 μs (0.00% GC) -------------- samples: 10000 evals/sample: 10 ```julia @benchmark $m2($vi2) ``` BenchmarkTools.Trial: memory estimate: 160 bytes allocs estimate: 5 -------------- minimum time: 1.759 μs (0.00% GC) median time: 1.778 μs (0.00% GC) mean time: 1.819 μs (0.00% GC) maximum time: 5.563 μs (0.00% GC) -------------- samples: 10000 evals/sample: 10 ```julia @benchmark $m3($vi3) ``` BenchmarkTools.Trial: memory estimate: 160 bytes allocs estimate: 5 -------------- minimum time: 1.718 μs (0.00% GC) median time: 1.746 μs (0.00% GC) mean time: 1.787 μs (0.00% GC) maximum time: 5.758 μs (0.00% GC) -------------- samples: 10000 evals/sample: 10 ```julia @benchmark $m4($vi4) ``` BenchmarkTools.Trial: memory estimate: 160 bytes allocs estimate: 5 -------------- minimum time: 1.672 μs (0.00% GC) median time: 1.696 μs (0.00% GC) mean time: 1.756 μs (0.00% GC) maximum time: 4.882 μs (0.00% GC) -------------- samples: 10000 evals/sample: 10 Notice that the number of allocations have increased for the deeply nested model. Seems like the Julia compiler isn't too good at inferring the return-types of Turing-models? This seems to be the case too by looking at the lowered code. I haven't given this too much thought yet btw; likely is a way for us to help the compiler here.
1 parent f7531ba commit 0f7548d

File tree

7 files changed

+164
-4
lines changed

7 files changed

+164
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.19"
3+
version = "0.10.20"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/DynamicPPL.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ export AbstractVarInfo,
7979
LikelihoodContext,
8080
PriorContext,
8181
MiniBatchContext,
82+
PrefixContext,
8283
assume,
8384
dot_assume,
8485
observer,
@@ -96,7 +97,9 @@ export AbstractVarInfo,
9697
logjoint,
9798
pointwise_loglikelihoods,
9899
# Convenience macros
99-
@addlogprob!
100+
@addlogprob!,
101+
@submodel
102+
100103

101104
# Reexport
102105
using Distributions: loglikelihood
@@ -124,5 +127,6 @@ include("compiler.jl")
124127
include("prob_macro.jl")
125128
include("compat/ad.jl")
126129
include("loglikelihoods.jl")
130+
include("submodel_macro.jl")
127131

128132
end # module

src/compiler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
5454
#################
5555

5656
"""
57-
@model(expr[, warn = true])
57+
@model(expr[, warn = false])
5858
5959
Macro to specify a probabilistic model.
6060
@@ -73,7 +73,7 @@ end
7373
7474
To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
7575
"""
76-
macro model(expr, warn=true)
76+
macro model(expr, warn=false)
7777
# include `LineNumberNode` with information about the call site in the
7878
# generated function for easier debugging and interpretation of error messages
7979
esc(model(__module__, __source__, expr, warn))

src/context_implementations.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ end
3939
function tilde(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi)
4040
return tilde(rng, ctx.ctx, sampler, right, left, inds, vi)
4141
end
42+
function tilde(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi)
43+
return tilde(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi)
44+
end
4245

4346
"""
4447
tilde_assume(rng, ctx, sampler, right, vn, inds, vi)
@@ -75,6 +78,9 @@ end
7578
function tilde(ctx::MiniBatchContext, sampler, right, left, vi)
7679
return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi)
7780
end
81+
function tilde(ctx::PrefixContext, sampler, right, left, vi)
82+
return tilde(ctx.ctx, sampler, right, left, vi)
83+
end
7884

7985
"""
8086
tilde_observe(ctx, sampler, right, left, vname, vinds, vi)

src/contexts.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,29 @@ end
5252
function MiniBatchContext(ctx = DefaultContext(); batch_size, npoints)
5353
return MiniBatchContext(ctx, npoints/batch_size)
5454
end
55+
56+
57+
struct PrefixContext{Prefix, C} <: AbstractContext
58+
ctx::C
59+
end
60+
PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} = PrefixContext{Prefix, typeof(ctx)}(ctx)
61+
62+
const PREFIX_SEPARATOR = Symbol(".")
63+
64+
function PrefixContext{PrefixInner}(
65+
ctx::PrefixContext{PrefixOuter}
66+
) where {PrefixInner, PrefixOuter}
67+
if @generated
68+
:(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}(ctx.ctx))
69+
else
70+
PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx)
71+
end
72+
end
73+
74+
function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix, Sym}
75+
if @generated
76+
return :(VarName{$(QuoteNode(Symbol(Prefix, _prefix_seperator, Sym)))}(vn.indexing))
77+
else
78+
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing)
79+
end
80+
end

src/submodel_macro.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
macro submodel(expr)
2+
return quote
3+
_evaluate(
4+
$(esc(:__rng__)),
5+
$(esc(expr)),
6+
$(esc(:__varinfo__)),
7+
$(esc(:__sampler__)),
8+
$(esc(:__context__))
9+
)
10+
end
11+
end
12+
13+
macro submodel(prefix, expr)
14+
return quote
15+
_evaluate(
16+
$(esc(:__rng__)),
17+
$(esc(expr)),
18+
$(esc(:__varinfo__)),
19+
$(esc(:__sampler__)),
20+
PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__)))
21+
)
22+
end
23+
end

test/compiler.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,107 @@ end
314314
@test demo2()() == 42
315315
end
316316

317+
@testset "submodel" begin
318+
# No prefix, 1 level.
319+
@model function demo1(x)
320+
x ~ Normal()
321+
end;
322+
@model function demo2(x, y)
323+
@submodel demo1(x)
324+
y ~ Uniform()
325+
end;
326+
# No observation.
327+
m = demo2(missing, missing);
328+
vi = VarInfo(m);
329+
ks = keys(vi)
330+
@test VarName(:x) ks
331+
@test VarName(:y) ks
332+
333+
# Observation in top-level.
334+
m = demo2(missing, 1.0);
335+
vi = VarInfo(m);
336+
ks = keys(vi)
337+
@test VarName(:x) ks
338+
@test VarName(:y) ks
339+
340+
# Observation in nested model.
341+
m = demo2(1000.0, missing);
342+
vi = VarInfo(m);
343+
ks = keys(vi)
344+
@test VarName(:x) ks
345+
@test VarName(:y) ks
346+
347+
# Observe all.
348+
m = demo2(1000.0, 0.5);
349+
vi = VarInfo(m);
350+
ks = keys(vi)
351+
@test isempty(ks)
352+
353+
# Check values makes sense.
354+
@model function demo2(x, y)
355+
@submodel demo1(x)
356+
y ~ Normal(x)
357+
end;
358+
m = demo2(1000.0, missing);
359+
# Mean of `y` should be close to 1000.
360+
@test abs(mean([VarInfo(m)[VarName(:y)] for i = 1:10]) - 1000) 10;
361+
362+
# Prefixed submodels and usage of submodel return values.
363+
@model function demo_return(x)
364+
x ~ Normal()
365+
return x
366+
end;
367+
368+
@model function demo_useval(x, y)
369+
x1 = @submodel sub1 demo_return(x)
370+
x2 = @submodel sub2 demo_return(y)
371+
372+
z ~ Normal(x1 + x2 + 100, 1.0)
373+
end;
374+
m = demo_useval(missing, missing)
375+
vi = VarInfo(m);
376+
ks = keys(vi)
377+
@test VarName(Symbol("sub1.x")) ks
378+
@test VarName(Symbol("sub2.x")) ks
379+
@test VarName(:z) ks
380+
@test abs(mean([VarInfo(m)[VarName(:z)] for i = 1:10]) - 100) 10
381+
382+
# AR1 model. Dynamic prefixing.
383+
@model function AR1(num_steps, α, μ, σ, ::Type{TV} = Vector{Float64}) where {TV}
384+
η ~ MvNormal(num_steps, 1.0)
385+
δ = sqrt(1 - α^2)
386+
387+
x = TV(undef, num_steps)
388+
x[1] = η[1]
389+
@inbounds for t = 2:num_steps
390+
x[t] = @. α * x[t - 1] + δ * η[t]
391+
end
392+
393+
return @. μ + σ * x
394+
end
395+
396+
@model function demo(y)
397+
α ~ Uniform()
398+
μ ~ Normal()
399+
σ ~ truncated(Normal(), 0, Inf)
400+
401+
num_steps = length(y[1])
402+
num_obs = length(y)
403+
@inbounds for i = 1:num_obs
404+
x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ)
405+
y[i] ~ MvNormal(x, 0.1)
406+
end
407+
end;
408+
409+
ys = [randn(10), randn(10)];
410+
m = demo(ys);
411+
vi = VarInfo(m);
412+
413+
for k in [, , , Symbol("ar1_1.η"), Symbol("ar1_2.η")]
414+
@test VarName(k) keys(vi)
415+
end
416+
end
417+
317418
@testset "check_tilde_rhs" begin
318419
@test_throws ArgumentError DynamicPPL.check_tilde_rhs(randn())
319420

0 commit comments

Comments
 (0)