Skip to content

Commit f8bdfaa

Browse files
committed
Make function signature consistent and predictable
1 parent c3d430c commit f8bdfaa

File tree

2 files changed

+55
-33
lines changed

2 files changed

+55
-33
lines changed

src/compiler.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function model(expr)
6666
modelinfo = build_model_info(expr)
6767

6868
# Generate main body
69-
modelinfo[:main_body] = generate_mainbody(modelinfo[:main_body], modelinfo[:args])
69+
modelinfo[:modelbody] = generate_mainbody(modelinfo[:body], modelinfo[:modelargs])
7070

7171
return build_output(modelinfo)
7272
end
@@ -84,7 +84,8 @@ function build_model_info(input_expr)
8484
# Construct model_info dictionary
8585

8686
# Extracting the argument symbols from the model definition
87-
arg_syms = map(modeldef[:args]) do arg
87+
combinedargs = vcat(modeldef[:args], modeldef[:kwargs])
88+
arg_syms = map(combinedargs) do arg
8889
# @model demo(x)
8990
if (arg isa Symbol)
9091
arg
@@ -110,7 +111,7 @@ function build_model_info(input_expr)
110111
)
111112
args_nt = Expr(:call, :($namedtuple), nt_type, Expr(:tuple, arg_syms...))
112113
end
113-
args = map(modeldef[:args]) do arg
114+
args = map(combinedargs) do arg
114115
if (arg isa Symbol)
115116
arg
116117
elseif MacroTools.@capture(arg, ::Type{T_} = Tval_)
@@ -131,7 +132,7 @@ function build_model_info(input_expr)
131132

132133
default_syms = []
133134
default_vals = []
134-
foreach(modeldef[:args]) do arg
135+
foreach(combinedargs) do arg
135136
# @model demo(::Type{T}) where {T}
136137
if MacroTools.@capture(arg, ::Type{T_} = Tval_)
137138
push!(default_syms, T)
@@ -148,15 +149,13 @@ function build_model_info(input_expr)
148149
end
149150
defaults_nt = to_namedtuple_expr(default_syms, default_vals)
150151

151-
model_info = Dict(
152-
:name => modeldef[:name],
153-
:main_body => modeldef[:body],
154-
:arg_syms => arg_syms,
155-
:args_nt => args_nt,
156-
:defaults_nt => defaults_nt,
157-
:args => args,
158-
:whereparams => modeldef[:whereparams]
152+
modelderiv = Dict(
153+
:modelargs => args,
154+
:modelargsyms => arg_syms,
155+
:modelargsnt => args_nt,
156+
:modeldefaultsnt => defaults_nt,
159157
)
158+
model_info = merge(modeldef, modelderiv)
160159

161160
return model_info
162161
end
@@ -313,20 +312,18 @@ Builds the output expression.
313312
"""
314313
function build_output(model_info)
315314
# Arguments with default values
316-
args = model_info[:args]
315+
args = model_info[:modelargs]
317316
# Argument symbols without default values
318-
arg_syms = model_info[:arg_syms]
317+
arg_syms = model_info[:modelargsyms]
319318
# Arguments namedtuple
320-
args_nt = model_info[:args_nt]
319+
args_nt = model_info[:modelargsnt]
321320
# Default values of the arguments
322321
# Arguments namedtuple
323-
defaults_nt = model_info[:defaults_nt]
324-
# Where parameters
325-
whereparams = model_info[:whereparams]
322+
defaults_nt = model_info[:modeldefaultsnt]
326323
# Model generator name
327324
model_gen = model_info[:name]
328325
# Main body of the model
329-
main_body = model_info[:main_body]
326+
main_body = model_info[:modelbody]
330327

331328
unwrap_data_expr = Expr(:block)
332329
for var in arg_syms
@@ -335,9 +332,13 @@ function build_output(model_info)
335332
end
336333

337334
@gensym(evaluator, generator)
338-
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
339335
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
340336

337+
# construct the user-facing model generator
338+
model_info[:name] = generator
339+
model_info[:body] = :(return $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor))
340+
generator_expr = MacroTools.combinedef(model_info)
341+
341342
return quote
342343
function $evaluator(
343344
_model::$(DynamicPPL.Model),
@@ -349,8 +350,7 @@ function build_output(model_info)
349350
$main_body
350351
end
351352

352-
$generator($(args...)) = $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor)
353-
$(generator_kw_form...)
353+
$(generator_expr)
354354

355355
$(Base).@__doc__ $model_gen = $model_gen_constructor
356356
end

test/compiler.jl

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,22 @@ end
135135
end
136136
f1_mm = testmodel1(1., 10.)
137137
@test f1_mm() == (1, 10)
138-
f1_mm = testmodel1(x1=1., x2=10.)
138+
139+
# alternatives with keyword arguments
140+
testmodel1kw(; x1, x2) = testmodel1(x1, x2)
141+
f1_mm = testmodel1kw(x1 = 1., x2 = 10.)
142+
@test f1_mm() == (1, 10)
143+
144+
@model function testmodel2(; x1, x2)
145+
s ~ InverseGamma(2,3)
146+
m ~ Normal(0,sqrt(s))
147+
148+
x1 ~ Normal(m, sqrt(s))
149+
x2 ~ Normal(m, sqrt(s))
150+
151+
return x1, x2
152+
end
153+
f1_mm = testmodel2(x1=1., x2=10.)
139154
@test f1_mm() == (1, 10)
140155

141156
@info "Testing the compiler's ability to catch bad models..."
@@ -199,7 +214,7 @@ end
199214
x ~ Bernoulli(0.5)
200215
return x
201216
end
202-
@test_throws UndefKeywordError testmodel()
217+
@test_throws MethodError testmodel()
203218

204219
# Test missing initialization for vector observation turned parameter
205220
@model testmodel(x) = begin
@@ -266,7 +281,7 @@ end
266281
chain = sample(gauss(x), PG(10), 10)
267282
chain = sample(gauss(x), SMC(), 10)
268283

269-
@model gauss2(x, ::Type{TV}=Vector{Float64}) where {TV} = begin
284+
@model function gauss2(::Type{TV} = Vector{Float64}; x) where {TV}
270285
priors = TV(undef, 2)
271286
priors[1] ~ InverseGamma(2,3) # s
272287
priors[2] ~ Normal(0, sqrt(priors[1])) # m
@@ -276,10 +291,11 @@ end
276291
priors
277292
end
278293

279-
chain = sample(gauss2(x), PG(10), 10)
280-
chain = sample(gauss2(x=x, TV=Vector{Float64}), PG(10), 10)
281-
chain = sample(gauss2(x), SMC(), 10)
282-
chain = sample(gauss2(x=x, TV=Vector{Float64}), SMC(), 10)
294+
chain = sample(gauss2(x = x), PG(10), 10)
295+
chain = sample(gauss2(x = x), SMC(), 10)
296+
297+
chain = sample(gauss2(Vector{Float64}; x = x), PG(10), 10)
298+
chain = sample(gauss2(Vector{Float64}; x = x), SMC(), 10)
283299
end
284300
@testset "new interface" begin
285301
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]
@@ -494,7 +510,7 @@ end
494510
setchunksize(N)
495511
alg = HMC(0.01, 5)
496512
x = randn(1000)
497-
@model vdemo1(::Type{T}=Float64) where {T} = begin
513+
@model function vdemo1(::Type{T}=Float64) where {T}
498514
x = Vector{T}(undef, N)
499515
for i = 1:N
500516
x[i] ~ Normal(0, sqrt(4))
@@ -503,7 +519,9 @@ end
503519

504520
t_loop = @elapsed res = sample(vdemo1(), alg, 250)
505521
t_loop = @elapsed res = sample(vdemo1(Float64), alg, 250)
506-
t_loop = @elapsed res = sample(vdemo1(T=Float64), alg, 250)
522+
523+
vdemo1kw(; T) = vdemo1(T)
524+
t_loop = @elapsed res = sample(vdemo1kw(T = Float64), alg, 250)
507525

508526
@model vdemo2(::Type{T}=Float64) where {T <: Real} = begin
509527
x = Vector{T}(undef, N)
@@ -512,7 +530,9 @@ end
512530

513531
t_vec = @elapsed res = sample(vdemo2(), alg, 250)
514532
t_vec = @elapsed res = sample(vdemo2(Float64), alg, 250)
515-
t_vec = @elapsed res = sample(vdemo2(T=Float64), alg, 250)
533+
534+
vdemo2kw(; T) = vdemo2(T)
535+
t_vec = @elapsed res = sample(vdemo2kw(T = Float64), alg, 250)
516536

517537
@model vdemo3(::Type{TV}=Vector{Float64}) where {TV <: AbstractVector} = begin
518538
x = TV(undef, N)
@@ -521,7 +541,9 @@ end
521541

522542
sample(vdemo3(), alg, 250)
523543
sample(vdemo3(Vector{Float64}), alg, 250)
524-
sample(vdemo3(TV=Vector{Float64}), alg, 250)
544+
545+
vdemo3kw(; T) = vdemo3(T)
546+
sample(vdemo3kw(T = Vector{Float64}), alg, 250)
525547
end
526548
@testset "var name splitting" begin
527549
var_expr = :(x)

0 commit comments

Comments
 (0)