Skip to content

Commit c85e73a

Browse files
authored
Merge pull request #109 from TuringLang/function_signature
Define model generator exactly as specified by the user
2 parents c40bde6 + 7b62284 commit c85e73a

File tree

2 files changed

+55
-33
lines changed

2 files changed

+55
-33
lines changed

src/compiler.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function model(expr, warn)
6969
modelinfo = build_model_info(expr)
7070

7171
# Generate main body
72-
modelinfo[:main_body] = generate_mainbody(modelinfo[:main_body], modelinfo[:args], warn)
72+
modelinfo[:modelbody] = generate_mainbody(modelinfo[:body], modelinfo[:modelargs], warn)
7373

7474
return build_output(modelinfo)
7575
end
@@ -87,7 +87,8 @@ function build_model_info(input_expr)
8787
# Construct model_info dictionary
8888

8989
# Extracting the argument symbols from the model definition
90-
arg_syms = map(modeldef[:args]) do arg
90+
combinedargs = vcat(modeldef[:args], modeldef[:kwargs])
91+
arg_syms = map(combinedargs) do arg
9192
# @model demo(x)
9293
if (arg isa Symbol)
9394
arg
@@ -113,7 +114,7 @@ function build_model_info(input_expr)
113114
)
114115
args_nt = Expr(:call, :($namedtuple), nt_type, Expr(:tuple, arg_syms...))
115116
end
116-
args = map(modeldef[:args]) do arg
117+
args = map(combinedargs) do arg
117118
if (arg isa Symbol)
118119
arg
119120
elseif MacroTools.@capture(arg, ::Type{T_} = Tval_)
@@ -134,7 +135,7 @@ function build_model_info(input_expr)
134135

135136
default_syms = []
136137
default_vals = []
137-
foreach(modeldef[:args]) do arg
138+
foreach(combinedargs) do arg
138139
# @model demo(::Type{T}) where {T}
139140
if MacroTools.@capture(arg, ::Type{T_} = Tval_)
140141
push!(default_syms, T)
@@ -151,15 +152,13 @@ function build_model_info(input_expr)
151152
end
152153
defaults_nt = to_namedtuple_expr(default_syms, default_vals)
153154

154-
model_info = Dict(
155-
:name => modeldef[:name],
156-
:main_body => modeldef[:body],
157-
:arg_syms => arg_syms,
158-
:args_nt => args_nt,
159-
:defaults_nt => defaults_nt,
160-
:args => args,
161-
:whereparams => modeldef[:whereparams]
155+
modelderiv = Dict(
156+
:modelargs => args,
157+
:modelargsyms => arg_syms,
158+
:modelargsnt => args_nt,
159+
:modeldefaultsnt => defaults_nt,
162160
)
161+
model_info = merge(modeldef, modelderiv)
163162

164163
return model_info
165164
end
@@ -319,20 +318,18 @@ Builds the output expression.
319318
"""
320319
function build_output(model_info)
321320
# Arguments with default values
322-
args = model_info[:args]
321+
args = model_info[:modelargs]
323322
# Argument symbols without default values
324-
arg_syms = model_info[:arg_syms]
323+
arg_syms = model_info[:modelargsyms]
325324
# Arguments namedtuple
326-
args_nt = model_info[:args_nt]
325+
args_nt = model_info[:modelargsnt]
327326
# Default values of the arguments
328327
# Arguments namedtuple
329-
defaults_nt = model_info[:defaults_nt]
330-
# Where parameters
331-
whereparams = model_info[:whereparams]
328+
defaults_nt = model_info[:modeldefaultsnt]
332329
# Model generator name
333330
model_gen = model_info[:name]
334331
# Main body of the model
335-
main_body = model_info[:main_body]
332+
main_body = model_info[:modelbody]
336333

337334
unwrap_data_expr = Expr(:block)
338335
for var in arg_syms
@@ -341,9 +338,13 @@ function build_output(model_info)
341338
end
342339

343340
@gensym(evaluator, generator)
344-
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
345341
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
346342

343+
# construct the user-facing model generator
344+
model_info[:name] = generator
345+
model_info[:body] = :(return $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor))
346+
generator_expr = MacroTools.combinedef(model_info)
347+
347348
return quote
348349
function $evaluator(
349350
_rng::$(Random.AbstractRNG),
@@ -356,8 +357,7 @@ function build_output(model_info)
356357
$main_body
357358
end
358359

359-
$generator($(args...)) = $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor)
360-
$(generator_kw_form...)
360+
$(generator_expr)
361361

362362
$(Base).@__doc__ $model_gen = $model_gen_constructor
363363
end

test/compiler.jl

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,22 @@ end
135135
end
136136
f1_mm = testmodel1(1., 10.)
137137
@test f1_mm() == (1, 10)
138-
f1_mm = testmodel1(x1=1., x2=10.)
138+
139+
# alternatives with keyword arguments
140+
testmodel1kw(; x1, x2) = testmodel1(x1, x2)
141+
f1_mm = testmodel1kw(x1 = 1., x2 = 10.)
142+
@test f1_mm() == (1, 10)
143+
144+
@model function testmodel2(; x1, x2)
145+
s ~ InverseGamma(2,3)
146+
m ~ Normal(0,sqrt(s))
147+
148+
x1 ~ Normal(m, sqrt(s))
149+
x2 ~ Normal(m, sqrt(s))
150+
151+
return x1, x2
152+
end
153+
f1_mm = testmodel2(x1=1., x2=10.)
139154
@test f1_mm() == (1, 10)
140155

141156
@info "Testing the compiler's ability to catch bad models..."
@@ -199,7 +214,7 @@ end
199214
x ~ Bernoulli(0.5)
200215
return x
201216
end
202-
@test_throws UndefKeywordError testmodel()
217+
@test_throws MethodError testmodel()
203218

204219
# Test missing initialization for vector observation turned parameter
205220
@model testmodel(x) = begin
@@ -285,7 +300,7 @@ end
285300
chain = sample(gauss(x), PG(10), 10)
286301
chain = sample(gauss(x), SMC(), 10)
287302

288-
@model gauss2(x, ::Type{TV}=Vector{Float64}) where {TV} = begin
303+
@model function gauss2(::Type{TV} = Vector{Float64}; x) where {TV}
289304
priors = TV(undef, 2)
290305
priors[1] ~ InverseGamma(2,3) # s
291306
priors[2] ~ Normal(0, sqrt(priors[1])) # m
@@ -295,10 +310,11 @@ end
295310
priors
296311
end
297312

298-
chain = sample(gauss2(x), PG(10), 10)
299-
chain = sample(gauss2(x=x, TV=Vector{Float64}), PG(10), 10)
300-
chain = sample(gauss2(x), SMC(), 10)
301-
chain = sample(gauss2(x=x, TV=Vector{Float64}), SMC(), 10)
313+
chain = sample(gauss2(x = x), PG(10), 10)
314+
chain = sample(gauss2(x = x), SMC(), 10)
315+
316+
chain = sample(gauss2(Vector{Float64}; x = x), PG(10), 10)
317+
chain = sample(gauss2(Vector{Float64}; x = x), SMC(), 10)
302318
end
303319
@testset "new interface" begin
304320
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
@@ -513,7 +529,7 @@ end
513529
setchunksize(N)
514530
alg = HMC(0.01, 5)
515531
x = randn(1000)
516-
@model vdemo1(::Type{T}=Float64) where {T} = begin
532+
@model function vdemo1(::Type{T}=Float64) where {T}
517533
x = Vector{T}(undef, N)
518534
for i = 1:N
519535
x[i] ~ Normal(0, sqrt(4))
@@ -522,7 +538,9 @@ end
522538

523539
t_loop = @elapsed res = sample(vdemo1(), alg, 250)
524540
t_loop = @elapsed res = sample(vdemo1(Float64), alg, 250)
525-
t_loop = @elapsed res = sample(vdemo1(T=Float64), alg, 250)
541+
542+
vdemo1kw(; T) = vdemo1(T)
543+
t_loop = @elapsed res = sample(vdemo1kw(T = Float64), alg, 250)
526544

527545
@model vdemo2(::Type{T}=Float64) where {T <: Real} = begin
528546
x = Vector{T}(undef, N)
@@ -531,7 +549,9 @@ end
531549

532550
t_vec = @elapsed res = sample(vdemo2(), alg, 250)
533551
t_vec = @elapsed res = sample(vdemo2(Float64), alg, 250)
534-
t_vec = @elapsed res = sample(vdemo2(T=Float64), alg, 250)
552+
553+
vdemo2kw(; T) = vdemo2(T)
554+
t_vec = @elapsed res = sample(vdemo2kw(T = Float64), alg, 250)
535555

536556
@model vdemo3(::Type{TV}=Vector{Float64}) where {TV <: AbstractVector} = begin
537557
x = TV(undef, N)
@@ -540,7 +560,9 @@ end
540560

541561
sample(vdemo3(), alg, 250)
542562
sample(vdemo3(Vector{Float64}), alg, 250)
543-
sample(vdemo3(TV=Vector{Float64}), alg, 250)
563+
564+
vdemo3kw(; T) = vdemo3(T)
565+
sample(vdemo3kw(T = Vector{Float64}), alg, 250)
544566
end
545567
@testset "var name splitting" begin
546568
var_expr = :(x)

0 commit comments

Comments
 (0)