Skip to content

Commit 4fdc1ab

Browse files
committed
integrate new traits: predict_scitype, etc
1 parent c7ea24b commit 4fdc1ab

File tree

3 files changed

+139
-0
lines changed

3 files changed

+139
-0
lines changed

src/MLJModelInterface.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ const MODEL_TRAITS = [
44
:input_scitype,
55
:output_scitype,
66
:target_scitype,
7+
:training_scitype,
8+
:predict_scitype,
9+
:transform_scitype,
10+
:inverse_transform_scitype,
711
:is_pure_julia,
812
:package_name,
913
:package_license,
@@ -18,6 +22,7 @@ const MODEL_TRAITS = [
1822
:name,
1923
:is_supervised,
2024
:prediction_type,
25+
:abstract_type,
2126
:implemented_methods,
2227
:hyperparameters,
2328
:hyperparameter_types,

src/model_traits.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,73 @@ StatisticalTraits.prediction_type(::Type{<:Interval}) = :interval
1616
implemented_methods(M::Type) = implemented_methods(get_interface_mode(), M)
1717
implemented_methods(model) = implemented_methods(typeof(model))
1818
implemented_methods(::LightInterface, M) = errlight("implemented_methods")
19+
20+
for M in ABSTRACT_MODEL_SUBTYPES
21+
@eval(StatisticalTraits.abstract_type(::Type{<:$M}) = $M)
22+
end
23+
24+
StatisticalTraits.training_scitype(M::Type{<:Model}) = input_scitype(M)
25+
StatisticalTraits.training_scitype(::Type{<:Static}) = Tuple{}
26+
function StatisticalTraits.training_scitype(M::Type{<:Supervised})
27+
I = input_scitype(M)
28+
T = target_scitype(M)
29+
ret = Tuple{I,T}
30+
if supports_weights(M)
31+
W = AbstractVector{Union{Continuous,Count}} # weight scitype
32+
return Union{ret,Tuple{I,T,W}}
33+
elseif supports_class_weights(M)
34+
W = AbstractDict{Finite,Union{Continuous,Count}}
35+
return Union{ret,Tuple{I,T,W}}
36+
end
37+
return ret
38+
end
39+
40+
StatisticalTraits.transform_scitype(M::Type{<:Unsupervised}) =
41+
output_scitype(M)
42+
43+
StatisticalTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
44+
input_scitype(M)
45+
46+
StatisticalTraits.predict_scitype(M::Type{<:Deterministic}) = target_scitype(M)
47+
48+
49+
## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` MODELS
50+
51+
# This seems less than ideal but should reduce the number of `Unknown`
52+
# in `prediction_type` for models which, historically, have not
53+
# implemented the trait.
54+
55+
StatisticalTraits.predict_scitype(M::Type{<:Probabilistic}) =
56+
_density(target_scitype(M))
57+
58+
_density(::Any) = Unknown
59+
for T in [:Continuous, :Count, :Textual]
60+
eval(quote
61+
_density(::Type{AbstractArray{$T,D}}) where D =
62+
AbstractArray{Density{$T},D}
63+
end)
64+
end
65+
66+
for T in [:Finite,
67+
:Multiclass,
68+
:OrderedFactor,
69+
:Infinite,
70+
:Continuous,
71+
:Count,
72+
:Textual]
73+
eval(quote
74+
_density(::Type{AbstractArray{<:$T,D}}) where D =
75+
AbstractArray{Density{<:$T},D}
76+
_density(::Type{Table($T)}) = Table(Density{$T})
77+
end)
78+
end
79+
80+
for T in [:Finite, :Multiclass, :OrderedFactor]
81+
eval(quote
82+
_density(::Type{AbstractArray{<:$T{N},D}}) where {N,D} =
83+
AbstractArray{Density{<:$T{N}},D}
84+
_density(::Type{AbstractArray{$T{N},D}}) where {N,D} =
85+
AbstractArray{Density{$T{N}},D}
86+
_density(::Type{Table($T{N})}) where N = Table(Density{$T{N}})
87+
end)
88+
end

test/model_traits.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,67 @@ import .Fruit
7979
@test docstring(Float64) == "Float64"
8080
@test docstring(Fruit.Banana) == "Banana"
8181
end
82+
83+
@testset "`_density` - helper for predict_scitype fallback" begin
84+
for T in [Continuous, Count, Textual]
85+
@test M._density(AbstractArray{T,3}) ==
86+
AbstractArray{Density{T},3}
87+
end
88+
89+
for T in [Finite,
90+
Multiclass,
91+
OrderedFactor,
92+
Infinite,
93+
Continuous,
94+
Count,
95+
Textual]
96+
@test M._density(AbstractVector{<:T}) ==
97+
AbstractVector{Density{<:T}}
98+
@test M._density(Table(T)) == Table(Density{T})
99+
end
100+
101+
for T in [Finite, Multiclass, OrderedFactor]
102+
@test M._density(AbstractArray{<:T{2},3}) ==
103+
AbstractArray{Density{<:T{2}},3}
104+
@test M._density(AbstractArray{T{2},3}) ==
105+
AbstractArray{Density{T{2}},3}
106+
@test M._density(Table(T{2})) == Table(Density{T{2}})
107+
end
108+
end
109+
110+
@mlj_model mutable struct P2 <: Probabilistic end
111+
M.target_scitype(::Type{<:P2}) = AbstractVector{<:Multiclass}
112+
M.input_scitype(::Type{<:P2}) = Table(Continuous)
113+
114+
@mlj_model mutable struct U2 <: Unsupervised end
115+
M.output_scitype(::Type{<:U2}) = AbstractVector{<:Multiclass}
116+
M.input_scitype(::Type{<:U2}) = Table(Continuous)
117+
118+
@mlj_model mutable struct S2 <: Static end
119+
M.output_scitype(::Type{<:S2}) = AbstractVector{<:Multiclass}
120+
M.input_scitype(::Type{<:S2}) = Table(Continuous)
121+
122+
@testset "operation scitypes" begin
123+
@test predict_scitype(P2()) == AbstractVector{Density{<:Multiclass}}
124+
@test transform_scitype(P2()) == Unknown
125+
@test transform_scitype(U2()) == AbstractVector{<:Multiclass}
126+
@test inverse_transform_scitype(U2()) == Table(Continuous)
127+
@test predict_scitype(U2()) == Unknown
128+
@test transform_scitype(S2()) == AbstractVector{<:Multiclass}
129+
@test inverse_transform_scitype(S2()) == Table(Continuous)
130+
end
131+
132+
@testset "abstract_type, training_scitype" begin
133+
@test abstract_type(P2()) == Probabilistic
134+
@test abstract_type(S1()) == Supervised
135+
@test abstract_type(U1()) == Unsupervised
136+
@test abstract_type(D1()) == Deterministic
137+
@test abstract_type(P1()) == Probabilistic
138+
139+
@test training_scitype(P2()) ==
140+
Tuple{Table(Continuous),AbstractVector{<:Multiclass}}
141+
@test training_scitype(U2()) == Table(Continuous)
142+
@test training_scitype(S2()) == Tuple{}
143+
end
144+
145+
true

0 commit comments

Comments
 (0)