Skip to content

Commit 776852d

Browse files
authored
Merge pull request #180 from JuliaAI/dev
For a 1.9.1 release
2 parents fe9492d + d33265f commit 776852d

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "1.9.0"
4+
version = "1.9.1"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/parameter_inspection.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ function params(m, ::Val{true})
3030
return NamedTuple{fields}(Tuple([params(getfield(m, field)) for field in fields]))
3131
end
3232

33-
isamodel(::Any) = false
34-
isamodel(::Model) = true
33+
isnotaleaf(::Any) = false
34+
isnotaleaf(m::Model) = length(propertynames(m)) > 0
3535

3636
"""
3737
flat_params(m::Model)
@@ -53,7 +53,7 @@ not a hard requirement.
5353
parallel = true,)
5454
5555
"""
56-
flat_params(m; prefix="") = flat_params(m, Val(isamodel(m)); prefix=prefix)
56+
flat_params(m; prefix="") = flat_params(m, Val(isnotaleaf(m)); prefix=prefix)
5757
flat_params(m, ::Val{false}; prefix="") = NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,))
5858
function flat_params(m, ::Val{true}; prefix="")
5959
fields = propertynames(m)

test/parameter_inspection.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ end
3535
end
3636

3737
struct ChildModel <: Model
38-
x::Int
39-
y::String
38+
r::Int
39+
s
4040
end
4141

4242
struct ParentModel <: Model
@@ -46,18 +46,20 @@ struct ParentModel <: Model
4646
second_child::ChildModel
4747
end
4848

49+
struct Missy <: Model end
50+
4951
@testset "flat_params method" begin
5052

5153
m = ParentModel(1, "parent", ChildModel(2, "child1"),
52-
ChildModel(3, "child2"))
54+
ChildModel(3, Missy()))
5355

5456
@test MLJModelInterface.flat_params(m) == (
5557
x = 1,
5658
y = "parent",
57-
first_child__x = 2,
58-
first_child__y = "child1",
59-
second_child__x = 3,
60-
second_child__y = "child2"
59+
first_child__r = 2,
60+
first_child__s = "child1",
61+
second_child__r = 3,
62+
second_child__s = Missy()
6163
)
6264
end
6365
true

0 commit comments

Comments
 (0)