Skip to content

Commit e0c5d03

Browse files
committed
2 parents c5da2c1 + c878e38 commit e0c5d03

File tree

3 files changed

+24
-19
lines changed

3 files changed

+24
-19
lines changed

src/model_traits.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,14 @@ docstring(M::Type{<:MLJType}) = name(M)
3636
docstring(M::Type{<:Model}) = "$(name(M)) from $(package_name(M)).jl.\n" *
3737
"[Documentation]($(package_url(M)))."
3838
# "derived" traits:
39-
typename(s) = replace(s, r"typename\((.*?)\)" => s"\1")
40-
name(M::Type) = string(M) |> typename
41-
name(M::Type{<:MLJType}) = split(string(coretype(M)), '.')[end] |> String |> typename
42-
43-
39+
function _coretype(M)
40+
if isdefined(M, :name)
41+
return M.name.name
42+
else
43+
return _coretype(M.body)
44+
end
45+
end
46+
name(M::Type) = string(_coretype(M))
4447
is_supervised(::Type) = false
4548
is_supervised(::Type{<:Supervised}) = true
4649
prediction_type(::Type) = :unknown # used for measures too

src/utils.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,3 @@ if VERSION < v"1.1"
22
fieldtypes(t) = Tuple(fieldtype(t, i) for i = 1:fieldcount(t))
33
end
44

5-
function coretype(M)
6-
if isdefined(M, :name)
7-
return M.name
8-
else
9-
return coretype(M.body)
10-
end
11-
end

test/model_traits.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,27 @@ bar(::P1) = nothing
6464
@test Set(implemented_methods(mp)) == Set([:clean!,:bar,:foo])
6565
end
6666

67-
struct FooMeasure <: MLJType end
67+
module Fruit
68+
69+
import MLJModelInterface.MLJType
70+
71+
struct Banana <: MLJType end
72+
struct Apple end
73+
74+
end
75+
76+
import .Banana
6877

6978
@testset "extras" begin
7079
@test docstring(Float64) == "Float64"
71-
@test docstring(FooMeasure) == "FooMeasure"
80+
@test docstring(Fruit.Banana) == "Banana"
7281
@test name(Float64) == "Float64"
7382

7483
df = DataFrame(a=randn(2), b=randn(2))
75-
@static if VERSION < v"1.6-"
76-
@test string(M.coretype(typeof(df))) == "DataFrame"
77-
else
78-
@test string(M.coretype(typeof(df))) == "typename(DataFrame)"
79-
end
8084
@test M.name(typeof(df)) == "DataFrame"
85+
@test M.name(df) == "DataFrame"
86+
@test M.name(Fruit.Banana) == "Banana"
87+
@test M.name(Fruit.Banana()) == "Banana"
88+
@test M.name(Fruit.Apple) == "Apple"
89+
@test M.name(Fruit.Apple()) == "Apple"
8190
end

0 commit comments

Comments
 (0)