Skip to content

Commit 0ffa0e5

Browse files
torfjeldeyebai
andcommitted
Fix for default arguments (#474)
Currently on master: ``` julia julia> using DynamicPPL, Distributions, LinearAlgebra julia> @model test_defaults(x, n=length(x)) = x ~ MvNormal(zeros(n), I) test_defaults (generic function with 3 methods) julia> test_defaults(missing, 2)() ERROR: MethodError: no method matching length(::Missing) ... ``` On this branch: ``` julia julia> using DynamicPPL, Distributions, LinearAlgebra julia> @model test_defaults(x, n=length(x)) = x ~ MvNormal(zeros(n), I) test_defaults (generic function with 3 methods) julia> test_defaults(missing, 2)() 2-element Vector{Float64}: 0.3028550279042859 0.6130034982853375 ``` On master, this is caused by ``` julia julia> @macroexpand @model test_defaults(x, n=length(x)) = x ~ MvNormal(zeros(n), I) quote function test_defaults(__model__::Model, __varinfo__::AbstractVarInfo, __context__::AbstractPPL.AbstractContext, x, n; ) ... end begin $(Expr(:meta, :doc)) function test_defaults(x, n = length(x); ) # This is the offending line. return (Model)(test_defaults, NamedTuple{(:x, :n)}((x, n)), NamedTuple{(:n,)}((length(x),))) end end end ``` And subsequently fixed in this PR, in which case the above is ```julia julia> @macroexpand @model test_defaults(x, n=length(x)) = x ~ MvNormal(zeros(n), I) quote function test_defaults(__model__::Model, __varinfo__::AbstractVarInfo, __context__::AbstractPPL.AbstractContext, x, n; ) ... end begin $(Expr(:meta, :doc)) function test_defaults(x, n = length(x); ) # Now we respect the argument `n`. return (Model)(test_defaults, NamedTuple{(:x, :n)}((x, n)), NamedTuple{(:n,)}((n,))) end end end ``` Co-authored-by: Hong Ge <[email protected]>
1 parent 409c0ed commit 0ffa0e5

File tree

4 files changed

+16
-3
lines changed

4 files changed

+16
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.22.2"
3+
version = "0.22.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ function build_model_info(input_expr)
287287
end
288288

289289
# Build named tuple expression of the argument symbols with default values.
290-
defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals)
290+
defaults_namedtuple = to_namedtuple_expr(default_syms)
291291

292292
modelinfo = Dict(
293293
:allargs_exprs => allargs_exprs,

src/utils.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,15 @@ function getargs_assignment(expr::Expr)
161161
end
162162
end
163163

164-
function to_namedtuple_expr(syms, vals=syms)
164+
function to_namedtuple_expr(syms)
165+
length(syms) == 0 && return :(NamedTuple())
166+
167+
names_expr = Expr(:tuple, QuoteNode.(syms)...)
168+
return :(NamedTuple{$names_expr}(($(syms...),)))
169+
end
170+
171+
# FIXME: the prob macro still uses this.
172+
function to_namedtuple_expr(syms, vals)
165173
length(syms) == 0 && return :(NamedTuple())
166174

167175
names_expr = Expr(:tuple, QuoteNode.(syms)...)

test/model.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,9 @@ end
149149
Random.seed!(1776)
150150
@test rand(Dict, model) == sample_dict
151151
end
152+
153+
@testset "default arguments" begin
154+
@model test_defaults(x, n=length(x)) = x ~ MvNormal(zeros(n), I)
155+
@test length(test_defaults(missing, 2)()) == 2
156+
end
152157
end

0 commit comments

Comments
 (0)