Skip to content

Commit ad82652

Browse files
authored
Merge pull request #181 from JuliaAI/dev
For a 1.9.2 release
2 parents 776852d + ffb47b1 commit ad82652

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "1.9.1"
4+
version = "1.9.2"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/model_def.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ function _process_model_def(modl, ex)
6262
if line.head == :(=) # assignment for default
6363
default = line.args[2]
6464
# if a constraint is given (value::constraint)
65-
if default isa Expr && length(default.args) > 1
65+
if default isa Expr && default.head == :(::)
6666
constraints[param] = default.args[2]
6767
# now discard the constraint to keep only the value
6868
default = default.args[1]
6969
end
70-
defaults[param] = default # this will be a value not an expr
70+
defaults[param] = default
7171

7272
# name or name::Type (for the constructor)
7373
ex.args[3].args[i] = line.args[1]

test/model_def.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,34 @@ end
151151
@test Cc().a === nothing
152152
@test Cd().a === missing
153153
end
154+
155+
@testset "Expression defaults" begin
156+
# Should work with and without constraint:
157+
@mlj_model mutable struct Foo1
158+
a::Vector{Int} = [1, 2, 3]
159+
end
160+
@test Foo1().a == [1, 2, 3]
161+
@mlj_model mutable struct Foo2
162+
a::Vector{Int} = [1, 2, 3]::(true)
163+
end
164+
@test Foo2().a == [1, 2, 3]
165+
166+
# Constraints applied
167+
@mlj_model mutable struct Foo3
168+
a::Vector{Int} = [1, 2, 3]::(all(>(0), _))
169+
end
170+
@test redirect_stderr(devnull) do
171+
Foo3(; a = [-1]).a == [1, 2, 3]
172+
end
173+
174+
# Negative number:
175+
@mlj_model mutable struct Foo4
176+
a::Float64 = -1.0
177+
end
178+
@test Foo4().a === -1.0
179+
@mlj_model mutable struct Foo5
180+
a::Float64 = (-1.0)::(true)
181+
end
182+
@test Foo5().a == -1.0
183+
184+
end

0 commit comments

Comments
 (0)