Skip to content

Commit 23c9573

Browse files
authored
Merge pull request #179 from JuliaAI/flat_params-hotfix
flat_params hotfix respect to empty models
2 parents fc6323b + f9de574 commit 23c9573

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

src/parameter_inspection.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ function params(m, ::Val{true})
3030
return NamedTuple{fields}(Tuple([params(getfield(m, field)) for field in fields]))
3131
end
3232

33-
isamodel(::Any) = false
34-
isamodel(::Model) = true
33+
isnotaleaf(::Any) = false
34+
isnotaleaf(m::Model) = length(propertynames(m)) > 0
3535

3636
"""
3737
flat_params(m::Model)
@@ -53,7 +53,7 @@ not a hard requirement.
5353
parallel = true,)
5454
5555
"""
56-
flat_params(m; prefix="") = flat_params(m, Val(isamodel(m)); prefix=prefix)
56+
flat_params(m; prefix="") = flat_params(m, Val(isnotaleaf(m)); prefix=prefix)
5757
flat_params(m, ::Val{false}; prefix="") = NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,))
5858
function flat_params(m, ::Val{true}; prefix="")
5959
fields = propertynames(m)

test/parameter_inspection.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ end
3535
end
3636

3737
struct ChildModel <: Model
38-
x::Int
39-
y::String
38+
r::Int
39+
s
4040
end
4141

4242
struct ParentModel <: Model
@@ -46,18 +46,20 @@ struct ParentModel <: Model
4646
second_child::ChildModel
4747
end
4848

49+
struct Missy <: Model end
50+
4951
@testset "flat_params method" begin
5052

5153
m = ParentModel(1, "parent", ChildModel(2, "child1"),
52-
ChildModel(3, "child2"))
54+
ChildModel(3, Missy()))
5355

5456
@test MLJModelInterface.flat_params(m) == (
5557
x = 1,
5658
y = "parent",
57-
first_child__x = 2,
58-
first_child__y = "child1",
59-
second_child__x = 3,
60-
second_child__y = "child2"
59+
first_child__r = 2,
60+
first_child__s = "child1",
61+
second_child__r = 3,
62+
second_child__s = Missy()
6163
)
6264
end
6365
true

0 commit comments

Comments
 (0)