Skip to content

Commit e9d85d4

Browse files
committed
**BREAKING** Make evaluator a method of the model function (#316)
This PR implements the idea in https://github.com/TuringLang/DynamicPPL.jl/pull/314/files#r693968700. All DynamicPPL tests pass, I only had to update the doctests (since the number of methods that are created by `@model` changed from 1 to 2). **This PR is breaking:** it breaks deserialization of models since Julia can't deserialize generic functions (see below). It allows to do the following: ```julia julia> using DynamicPPL, Distributions julia> @macroexpand @model demo() = x ~ Normal() quote function demo(__model__::Model, __varinfo__::AbstractVarInfo, __context__::DynamicPPL.AbstractContext; ) #= REPL[4]:1 =# begin var"##vn#257" = (VarName){:x}() var"##inds#258" = () var"##isassumption#259" = begin let var"##vn#260" = (VarName){:x}() if (DynamicPPL.contextual_isassumption)(__context__, var"##vn#260") if !((DynamicPPL.inargnames)(var"##vn#260", __model__)) || (DynamicPPL.inmissings)(var"##vn#260", __model__) true else x === missing end else false end end end if var"##isassumption#259" x = (DynamicPPL.tilde_assume!)(__context__, (DynamicPPL.unwrap_right_vn)((DynamicPPL.check_tilde_rhs)(Normal()), var"##vn#257")..., var"##inds#258", __varinfo__) else if !((DynamicPPL.inargnames)(var"##vn#257", __model__)) x = (DynamicPPL.getvalue_nested)(__context__, var"##vn#257") end (DynamicPPL.tilde_observe!)(__context__, (DynamicPPL.check_tilde_rhs)(Normal()), x, var"##vn#257", var"##inds#258", __varinfo__) end end end begin $(Expr(:meta, :doc)) function demo(; ) #= REPL[4]:1 =# return (Model)(:demo, demo, NamedTuple(), NamedTuple()) end end end julia> @model demo() = x ~ Normal() demo (generic function with 2 methods) julia> f(x) = false f (generic function with 1 method) julia> f(::Model{typeof(demo)}) = true f (generic function with 2 methods) julia> f(demo()) true julia> demo() isa Model{typeof(demo)} true ``` Co-authored-by: David Widmann <[email protected]>
1 parent 1c93ec3 commit e9d85d4

File tree

5 files changed

+56
-20
lines changed

5 files changed

+56
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.14.2"
3+
version = "0.15.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -470,10 +470,7 @@ Builds the output expression.
470470
"""
471471
function build_output(modelinfo, linenumbernode)
472472
## Build the anonymous evaluator from the user-provided model definition.
473-
474-
# Remove the name.
475473
evaluatordef = deepcopy(modelinfo[:modeldef])
476-
delete!(evaluatordef, :name)
477474

478475
# Add the internal arguments to the user-specified arguments (positional + keywords).
479476
evaluatordef[:args] = vcat(
@@ -489,7 +486,13 @@ function build_output(modelinfo, linenumbernode)
489486
evaluatordef[:kwargs] = []
490487

491488
# Replace the user-provided function body with the version created by DynamicPPL.
492-
evaluatordef[:body] = modelinfo[:body]
489+
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
490+
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
491+
# to the call site
492+
evaluatordef[:body] = MacroTools.@q begin
493+
$(linenumbernode)
494+
$(modelinfo[:body])
495+
end
493496

494497
## Build the model function.
495498

@@ -498,24 +501,24 @@ function build_output(modelinfo, linenumbernode)
498501
defaults_namedtuple = modelinfo[:defaults_namedtuple]
499502

500503
# Update the function body of the user-specified model.
501-
# We use a name for the anonymous evaluator that does not conflict with other variables.
502-
modeldef = modelinfo[:modeldef]
503-
@gensym evaluator
504504
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
505505
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
506506
# to the call site
507+
modeldef = modelinfo[:modeldef]
507508
modeldef[:body] = MacroTools.@q begin
508509
$(linenumbernode)
509-
$evaluator = $(MacroTools.combinedef(evaluatordef))
510510
return $(DynamicPPL.Model)(
511511
$(QuoteNode(modeldef[:name])),
512-
$evaluator,
512+
$(modeldef[:name]),
513513
$allargs_namedtuple,
514514
$defaults_namedtuple,
515515
)
516516
end
517517

518-
return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef)))
518+
return MacroTools.@q begin
519+
$(MacroTools.combinedef(evaluatordef))
520+
$(Base).@__doc__ $(MacroTools.combinedef(modeldef))
521+
end
519522
end
520523

521524
function warn_empty(body)

src/model.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ julia> @model function demo()
121121
x ~ Normal(m, 1)
122122
return (; m=m, x=x)
123123
end
124-
demo (generic function with 1 method)
124+
demo (generic function with 2 methods)
125125
126126
julia> model = demo();
127127
@@ -161,7 +161,7 @@ julia> @model function demo_mv(::Type{TV}=Float64) where {TV}
161161
m[2] ~ Normal()
162162
return m
163163
end
164-
demo_mv (generic function with 2 methods)
164+
demo_mv (generic function with 3 methods)
165165
166166
julia> model = demo_mv();
167167
@@ -192,13 +192,13 @@ the use of [`@submodel`](@ref).
192192
193193
```jldoctest condition
194194
julia> @model demo_inner() = m ~ Normal()
195-
demo_inner (generic function with 1 method)
195+
demo_inner (generic function with 2 methods)
196196
197197
julia> @model function demo_outer()
198198
m = @submodel demo_inner()
199199
return m
200200
end
201-
demo_outer (generic function with 1 method)
201+
demo_outer (generic function with 2 methods)
202202
203203
julia> model = demo_outer();
204204
@@ -218,7 +218,7 @@ julia> @model function demo_outer_prefix()
218218
m = @submodel inner demo_inner()
219219
return m
220220
end
221-
demo_outer_prefix (generic function with 1 method)
221+
demo_outer_prefix (generic function with 2 methods)
222222
223223
julia> # This doesn't work now!
224224
conditioned_model = demo_outer_prefix() | (m = 1.0, );
@@ -279,7 +279,7 @@ julia> @model function demo()
279279
x ~ Normal(m, 1)
280280
return (; m=m, x=x)
281281
end
282-
demo (generic function with 1 method)
282+
demo (generic function with 2 methods)
283283
284284
julia> conditioned_model = condition(demo(), m = 1.0, x = 10.0);
285285
@@ -333,7 +333,7 @@ julia> @model function demo()
333333
m ~ Normal()
334334
x ~ Normal(m, 1)
335335
end
336-
demo (generic function with 1 method)
336+
demo (generic function with 2 methods)
337337
338338
julia> m = demo();
339339
@@ -613,7 +613,7 @@ julia> @model function demo(xs)
613613
end
614614
return (m, )
615615
end
616-
demo (generic function with 1 method)
616+
demo (generic function with 2 methods)
617617
618618
julia> model = demo(randn(10));
619619

test/compiler.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ end
3939

4040
return x, y
4141
end
42+
@test length(methods(testmodel_comp)) == 2
4243
testmodel_comp(1.0, 1.2)
4344

4445
# check if drawing from the prior works
4546
@model function testmodel01(x=missing)
4647
x ~ Normal()
4748
return x
4849
end
50+
@test length(methods(testmodel01)) == 3
4951
f0_mm = testmodel01()
5052
@test mean(f0_mm() for _ in 1:1000) 0.0 atol = 0.1
5153

@@ -58,6 +60,7 @@ end
5860
x[2] ~ Normal()
5961
return x
6062
end
63+
@test length(methods(testmodel02)) == 3
6164
f0_mm = testmodel02()
6265
@test all(x -> isapprox(x, 0; atol=0.1), mean(f0_mm() for _ in 1:1000))
6366

@@ -66,6 +69,7 @@ end
6669
return x
6770
end
6871
f01_mm = testmodel03()
72+
@test length(methods(testmodel03)) == 3
6973
@test mean(f01_mm() for _ in 1:1000) 0.5 atol = 0.1
7074

7175
# test if we get the correct return values
@@ -78,6 +82,7 @@ end
7882

7983
return x1, x2
8084
end
85+
@test length(methods(testmodel1)) == 2
8186
f1_mm = testmodel1(1.0, 10.0)
8287
@test f1_mm() == (1, 10)
8388

@@ -95,6 +100,7 @@ end
95100

96101
return x1, x2
97102
end
103+
@test length(methods(testmodel2)) == 2
98104
f1_mm = testmodel2(; x1=1.0, x2=10.0)
99105
@test f1_mm() == (1, 10)
100106

@@ -461,4 +467,31 @@ end
461467
model = @model(x -> (x ~ Normal()))
462468
end
463469
end
470+
471+
@testset "dispatching with model" begin
472+
f(x) = false
473+
474+
@model demo() = x ~ Normal()
475+
@test !f(demo())
476+
f(::Model{typeof(demo)}) = true
477+
@test f(demo())
478+
479+
# Leads to re-definition of `demo` and trait is not affected.
480+
@test length(methods(demo)) == 2
481+
@model demo() = x ~ Normal()
482+
@test length(methods(demo)) == 2
483+
@test f(demo())
484+
485+
# Ensure we can specialize on arguments.
486+
@model demo(x) = x ~ Normal()
487+
length(methods(demo))
488+
@test f(demo(1.0))
489+
f(::Model{typeof(demo),(:x,)}) = false
490+
@test !f(demo(1.0))
491+
@test f(demo()) # should still be `true`
492+
493+
# Set it to `false` again.
494+
f(::Model{typeof(demo),()}) = false
495+
@test !f(demo())
496+
end
464497
end

test/turing/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
66

77
[compat]
8-
DynamicPPL = "0.14"
8+
DynamicPPL = "0.15"
99
Turing = "0.17"
1010
julia = "1.3"

0 commit comments

Comments
 (0)