Skip to content

Commit 0d06ec2

Browse files
authored
Merge pull request #177 from pebeto/dev
Adding flat_params function
2 parents 60ea476 + 9cc45b8 commit 0d06ec2

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

src/parameter_inspection.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,33 @@ function params(m, ::Val{true})
3030
return NamedTuple{fields}(Tuple([params(getfield(m, field)) for field in fields]))
3131
end
3232

33+
isamodel(::Any) = false
34+
isamodel(::Model) = true
3335

36+
"""
37+
flat_params(m::Model)
38+
39+
Recursively convert any object subtyping `Model` into a named tuple, keyed on
40+
the property names of `m`. The named tuple is possibly nested because
41+
`flat_params` is recursively applied to the property values, which themselves
42+
might subtype `Model`.
43+
44+
For most `Model` objects, properties are synonymous with fields, but this is
45+
not a hard requirement.
3446
47+
julia> flat_params(EnsembleModel(atom=ConstantClassifier()))
48+
(atom = (target_type = Bool,),
49+
weights = Float64[],
50+
bagging_fraction = 0.8,
51+
rng_seed = 0,
52+
n = 100,
53+
parallel = true,)
3554
55+
"""
56+
flat_params(m; prefix="") = flat_params(m, Val(isamodel(m)); prefix=prefix)
57+
flat_params(m, ::Val{false}; prefix="") = NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,))
58+
function flat_params(m, ::Val{true}; prefix="")
59+
fields = propertynames(m)
60+
prefix = prefix == "" ? "" : prefix * "__"
61+
merge([flat_params(getproperty(m, field); prefix="$(prefix)$(field)") for field in fields]...)
62+
end

test/parameter_inspection.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ end
1212

1313
MLJModelInterface.istransparent(::Transparent) = true
1414

15-
struct Dummy <:MLJType
15+
struct Dummy <: MLJType
1616
t::Transparent
1717
o::Opaque
1818
n::Integer
1919
end
2020

21+
2122
@testset "params method" begin
2223

2324
t= Transparent(6, Opaque(5))
@@ -32,4 +33,31 @@ end
3233
n = 42
3334
)
3435
end
36+
37+
struct ChildModel <: Model
38+
x::Int
39+
y::String
40+
end
41+
42+
struct ParentModel <: Model
43+
x::Int
44+
y::String
45+
first_child::ChildModel
46+
second_child::ChildModel
47+
end
48+
49+
@testset "flat_params method" begin
50+
51+
m = ParentModel(1, "parent", ChildModel(2, "child1"),
52+
ChildModel(3, "child2"))
53+
54+
@test MLJModelInterface.flat_params(m) == (
55+
x = 1,
56+
y = "parent",
57+
first_child__x = 2,
58+
first_child__y = "child1",
59+
second_child__x = 3,
60+
second_child__y = "child2"
61+
)
62+
end
3563
true

0 commit comments

Comments
 (0)