Skip to content

Commit 485f839

Browse files
authored
Merge pull request #82 from alan-turing-institute/dev
For a 0.3.8 release
2 parents 996baac + e0c5d03 commit 485f839

File tree

4 files changed

+26
-14
lines changed

4 files changed

+26
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "0.3.7"
4+
version = "0.3.8"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/model_traits.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,14 @@ docstring(M::Type{<:MLJType}) = name(M)
3636
docstring(M::Type{<:Model}) = "$(name(M)) from $(package_name(M)).jl.\n" *
3737
"[Documentation]($(package_url(M)))."
3838
# "derived" traits:
39-
name(M::Type) = string(M)
40-
name(M::Type{<:MLJType}) = split(string(coretype(M)), '.')[end] |> String
41-
39+
function _coretype(M)
40+
if isdefined(M, :name)
41+
return M.name.name
42+
else
43+
return _coretype(M.body)
44+
end
45+
end
46+
name(M::Type) = string(_coretype(M))
4247
is_supervised(::Type) = false
4348
is_supervised(::Type{<:Supervised}) = true
4449
prediction_type(::Type) = :unknown # used for measures too

src/utils.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,3 @@ if VERSION < v"1.1"
22
fieldtypes(t) = Tuple(fieldtype(t, i) for i = 1:fieldcount(t))
33
end
44

5-
function coretype(M)
6-
if isdefined(M, :name)
7-
return M.name
8-
else
9-
return coretype(M.body)
10-
end
11-
end

test/model_traits.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,27 @@ bar(::P1) = nothing
6464
@test Set(implemented_methods(mp)) == Set([:clean!,:bar,:foo])
6565
end
6666

67-
struct FooMeasure <: MLJType end
67+
module Fruit
68+
69+
import MLJModelInterface.MLJType
70+
71+
struct Banana <: MLJType end
72+
struct Apple end
73+
74+
end
75+
76+
import .Banana
6877

6978
@testset "extras" begin
7079
@test docstring(Float64) == "Float64"
71-
@test docstring(FooMeasure) == "FooMeasure"
80+
@test docstring(Fruit.Banana) == "Banana"
7281
@test name(Float64) == "Float64"
7382

7483
df = DataFrame(a=randn(2), b=randn(2))
75-
@test string(M.coretype(typeof(df))) == "DataFrame"
84+
@test M.name(typeof(df)) == "DataFrame"
85+
@test M.name(df) == "DataFrame"
86+
@test M.name(Fruit.Banana) == "Banana"
87+
@test M.name(Fruit.Banana()) == "Banana"
88+
@test M.name(Fruit.Apple) == "Apple"
89+
@test M.name(Fruit.Apple()) == "Apple"
7690
end

0 commit comments

Comments
 (0)