Skip to content

Commit 9cc45b8

Browse files
committed
using NamedTuple instead of a Dict in flat_params
1 parent c3f0f99 commit 9cc45b8

File tree

3 files changed

+35
-18
lines changed

3 files changed

+35
-18
lines changed

src/MLJModelInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ end
8080
export UnivariateFinite
8181

8282
# parameter_inspection:
83-
export params, flat_params
83+
export params
8484

8585
# model constructor + metadata
8686
export @mlj_model, metadata_pkg, metadata_model

src/parameter_inspection.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ isamodel(::Model) = true
3636
"""
3737
flat_params(m::Model)
3838
39-
Recursively convert any object subtyping `Model` into a dictionary, keyed on
40-
the property names of `m`. The dictionary is possibly nested because
39+
Recursively convert any object subtyping `Model` into a named tuple, keyed on
40+
the property names of `m`. The named tuple is possibly nested because
4141
`flat_params` is recursively applied to the property values, which themselves
4242
might subtype `Model`.
4343
@@ -54,7 +54,7 @@ not a hard requirement.
5454
5555
"""
5656
flat_params(m; prefix="") = flat_params(m, Val(isamodel(m)); prefix=prefix)
57-
flat_params(m, ::Val{false}; prefix="") = Dict(prefix=>m)
57+
flat_params(m, ::Val{false}; prefix="") = NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,))
5858
function flat_params(m, ::Val{true}; prefix="")
5959
fields = propertynames(m)
6060
prefix = prefix == "" ? "" : prefix * "__"

test/parameter_inspection.jl

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
using Test
22
using MLJModelInterface
33

4-
struct Opaque <: Model
4+
struct Opaque
55
a::Int
66
end
77

8-
struct Transparent <: Model
8+
struct Transparent
99
A::Int
1010
B::Opaque
1111
end
1212

1313
MLJModelInterface.istransparent(::Transparent) = true
1414

15-
struct Dummy <: Model
15+
struct Dummy <: MLJType
1616
t::Transparent
1717
o::Opaque
1818
n::Integer
@@ -27,20 +27,37 @@ end
2727
@test params(m) == (
2828
t = (
2929
A = 6,
30-
B = (
31-
a = 5,
32-
)
33-
),
34-
o = (
35-
a = 7,
30+
B = Opaque(5)
3631
),
32+
o = Opaque(7),
3733
n = 42
3834
)
39-
@test flat_params(m) == Dict(
40-
"o__a" => 7,
41-
"t__A" => 6,
42-
"t__B__a" => 5,
43-
"n" => 42
35+
end
36+
37+
struct ChildModel <: Model
38+
x::Int
39+
y::String
40+
end
41+
42+
struct ParentModel <: Model
43+
x::Int
44+
y::String
45+
first_child::ChildModel
46+
second_child::ChildModel
47+
end
48+
49+
@testset "flat_params method" begin
50+
51+
m = ParentModel(1, "parent", ChildModel(2, "child1"),
52+
ChildModel(3, "child2"))
53+
54+
@test MLJModelInterface.flat_params(m) == (
55+
x = 1,
56+
y = "parent",
57+
first_child__x = 2,
58+
first_child__y = "child1",
59+
second_child__x = 3,
60+
second_child__y = "child2"
4461
)
4562
end
4663
true

0 commit comments

Comments
 (0)