Skip to content

Commit 347cc85

Browse files
authored
Merge pull request #51 from alan-turing-institute/dev
For a O.3.0 release
2 parents a064ba3 + 9429cfa commit 347cc85

File tree

6 files changed

+73
-7
lines changed

6 files changed

+73
-7
lines changed

Project.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "0.2.8"
4+
version = "0.3.0"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -16,10 +16,9 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
1616
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1717
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
1818
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
19-
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
2019
MLJScientificTypes = "2e2323e0-db8b-457b-ae0d-bdfb3bc63afd"
2120
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2221
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2322

2423
[targets]
25-
test = ["Test", "Tables", "Distances", "CategoricalArrays", "InteractiveUtils", "DataFrames", "MLJScientificTypes", "MLJBase"]
24+
test = ["Test", "Tables", "Distances", "CategoricalArrays", "InteractiveUtils", "DataFrames", "MLJScientificTypes"]

src/MLJModelInterface.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ export MLJType, Model, Supervised, Unsupervised,
1616
Probabilistic, Deterministic, Interval, Static,
1717
UnivariateFinite
1818

19+
# parameter_inspection:
20+
export params
21+
1922
# model constructor + metadata
2023
export @mlj_model, metadata_pkg, metadata_model
2124

@@ -89,9 +92,9 @@ abstract type Static <: Unsupervised end
8992
# includes
9093

9194
include("utils.jl")
95+
include("parameter_inspection.jl")
9296
include("data_utils.jl")
9397
include("metadata_utils.jl")
94-
9598
include("model_traits.jl")
9699
include("model_def.jl")
97100
include("model_api.jl")

src/parameter_inspection.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
istransparent(::Any) = false
2+
istransparent(::MLJType) = true
3+
4+
"""
5+
params(m::MLJType)
6+
7+
Recursively convert any transparent object `m` into a named tuple,
8+
keyed on the fields of `m`. An object is *transparent* if
9+
`MLJModelInterface.istransparent(m) == true`. The named tuple is
10+
possibly nested because `params` is recursively applied to the field
11+
values, which themselves might be transparent.
12+
13+
Most objects of type `MLJType` are transparent.
14+
15+
julia> params(EnsembleModel(atom=ConstantClassifier()))
16+
(atom = (target_type = Bool,),
17+
weights = Float64[],
18+
bagging_fraction = 0.8,
19+
rng_seed = 0,
20+
n = 100,
21+
parallel = true,)
22+
23+
"""
24+
params(m) = params(m, Val(istransparent(m)))
25+
params(m, ::Val{false}) = m
26+
function params(m, ::Val{true})
27+
fields = fieldnames(typeof(m))
28+
NamedTuple{fields}(Tuple([params(getfield(m, field)) for field in fields]))
29+
end
30+
31+
32+
33+

test/data_utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ end
7070
setfull()
7171
ary = rand(10, 3)
7272
@test M.schema(ary) === nothing
73-
M.schema(::FI, ::Val{:table}, X; kw...) = MLJBase.schema(X; kw...) # this would be defined in MLJBase.jl
73+
M.schema(::FI, ::Val{:table}, X; kw...) =
74+
MLJScientificTypes.schema(X; kw...)
7475
df = DataFrame(A = rand(10), B = categorical(rand('a':'c', 10)))
7576
sch = M.schema(df)
7677
@test sch.names == (:A, :B)

test/parameter_inspection.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using Test
2+
using MLJModelInterface
3+
4+
struct Opaque
5+
a::Int
6+
end
7+
8+
struct Transparent
9+
A::Int
10+
B::Opaque
11+
end
12+
13+
MLJModelInterface.istransparent(::Transparent) = true
14+
15+
struct Dummy <:MLJType
16+
t::Transparent
17+
o::Opaque
18+
n::Integer
19+
end
20+
21+
@testset "params method" begin
22+
23+
t= Transparent(6, Opaque(5))
24+
m = Dummy(t, Opaque(7), 42)
25+
26+
@test params(m) == (t = (A = 6,
27+
B = Opaque(5)),
28+
o = Opaque(7),
29+
n = 42)
30+
end
31+
true

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using Test, MLJModelInterface
22
using ScientificTypes, MLJScientificTypes
33
using Tables, Distances, CategoricalArrays, InteractiveUtils
44
import DataFrames: DataFrame
5-
import MLJBase
65

76
const M = MLJModelInterface
87
const FI = M.FullInterface
@@ -12,9 +11,9 @@ setlight() = M.set_interface_mode(M.LightInterface())
1211
setfull() = M.set_interface_mode(M.FullInterface())
1312

1413
include("mode.jl")
14+
include("parameter_inspection.jl")
1515
include("data_utils.jl")
1616
include("metadata_utils.jl")
17-
1817
include("model_def.jl")
1918
include("model_api.jl")
2019
include("model_traits.jl")

0 commit comments

Comments
 (0)