Skip to content

Commit 6288bac

Browse files
authored
Merge pull request #110 from JuliaAI/dev
For a 1.2 release
2 parents 9a7b6ba + fc7191c commit 6288bac

File tree

4 files changed

+143
-3
lines changed

4 files changed

+143
-3
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "1.1.3"
4+
version = "1.2.0"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
88
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
99
StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
1010

1111
[compat]
12-
ScientificTypesBase = "1, 2"
13-
StatisticalTraits = "2"
12+
ScientificTypesBase = "2.1"
13+
StatisticalTraits = "2.1"
1414
julia = "1"
1515

1616
[extras]

src/MLJModelInterface.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ const MODEL_TRAITS = [
44
:input_scitype,
55
:output_scitype,
66
:target_scitype,
7+
:fit_data_scitype,
8+
:predict_scitype,
9+
:transform_scitype,
10+
:inverse_transform_scitype,
711
:is_pure_julia,
812
:package_name,
913
:package_license,
@@ -18,6 +22,7 @@ const MODEL_TRAITS = [
1822
:name,
1923
:is_supervised,
2024
:prediction_type,
25+
:abstract_type,
2126
:implemented_methods,
2227
:hyperparameters,
2328
:hyperparameter_types,

src/model_traits.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,74 @@ StatisticalTraits.prediction_type(::Type{<:Interval}) = :interval
1616
# For a model type, look up the implemented methods via the method selected
# by the current interface mode.
function implemented_methods(M::Type)
    mode = get_interface_mode()
    return implemented_methods(mode, M)
end

# For anything else (e.g. a model instance), defer to the method for its type.
function implemented_methods(model)
    return implemented_methods(typeof(model))
end

# Under `LightInterface` the call is handed to `errlight`.
function implemented_methods(::LightInterface, M)
    return errlight("implemented_methods")
end
19+
20+
# Generate one `abstract_type` method per entry of `ABSTRACT_MODEL_SUBTYPES`:
# any type that subtypes the entry is mapped to that entry.
for A in ABSTRACT_MODEL_SUBTYPES
    @eval begin
        StatisticalTraits.abstract_type(::Type{<:$A}) = $A
    end
end
23+
24+
# `fit_data_scitype` for unsupervised models: the 1-tuple type wrapping
# `input_scitype(M)`.
function StatisticalTraits.fit_data_scitype(M::Type{<:Unsupervised})
    return Tuple{input_scitype(M)}
end

# `fit_data_scitype` for `Static` models: the empty tuple type.
function StatisticalTraits.fit_data_scitype(::Type{<:Static})
    return Tuple{}
end
27+
# `fit_data_scitype` for supervised models.
#
# The result always admits `Tuple{I,T}`, where `I = input_scitype(M)` and
# `T = target_scitype(M)`. If `supports_weights(M)` holds, or failing that
# `supports_class_weights(M)`, a 3-tuple alternative `Tuple{I,T,W}` carrying
# the corresponding weight scitype `W` is added to the returned union.
#
# The weight scitypes use `<:` bounds because Julia type parameters are
# invariant: `AbstractVector{Continuous}` is not a subtype of
# `AbstractVector{Union{Continuous,Count}}`, so without the bounds a weight
# container with a concrete element scitype could never match the 3-tuple
# alternative.
function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised})
    I = input_scitype(M)
    T = target_scitype(M)
    ret = Tuple{I,T}
    if supports_weights(M)
        W = AbstractVector{<:Union{Continuous,Count}} # weight scitype
        return Union{ret,Tuple{I,T,W}}
    elseif supports_class_weights(M)
        W = AbstractDict{<:Finite,<:Union{Continuous,Count}} # class-weight scitype
        return Union{ret,Tuple{I,T,W}}
    end
    return ret
end
40+
41+
# For unsupervised models, `transform_scitype` is the declared output scitype.
function StatisticalTraits.transform_scitype(M::Type{<:Unsupervised})
    return output_scitype(M)
end

# For unsupervised models, `inverse_transform_scitype` is the declared input
# scitype.
function StatisticalTraits.inverse_transform_scitype(M::Type{<:Unsupervised})
    return input_scitype(M)
end

# For deterministic models, `predict_scitype` is the declared target scitype.
function StatisticalTraits.predict_scitype(M::Type{<:Deterministic})
    return target_scitype(M)
end
48+
49+
50+
## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` MODELS
51+
52+
# This seems less than ideal but should reduce the number of `Unknown`
53+
# in `predict_scitype` for models which, historically, have not
54+
# implemented the trait.
55+
56+
# Catch-all: an argument with no more specific `_density` method maps to
# `Unknown`.
function _density(::Any)
    return Unknown
end

# For probabilistic models, `predict_scitype` is derived from the declared
# target scitype via `_density`.
function StatisticalTraits.predict_scitype(M::Type{<:Probabilistic})
    target = target_scitype(M)
    return _density(target)
end
60+
# Arrays whose element scitype is exactly `S`: wrap the element scitype in
# `Density`, keeping the array dimension `D`.
for S in (:Continuous, :Count, :Textual)
    @eval begin
        function _density(::Type{AbstractArray{$S,D}}) where D
            return AbstractArray{Density{$S},D}
        end
    end
end
66+
67+
# Arrays with element scitype bounded above by `S` (`<:S`), and tables built
# with `Table(S)`: the same shape is returned with `Density` applied to `S`.
for S in (
    :Finite,
    :Multiclass,
    :OrderedFactor,
    :Infinite,
    :Continuous,
    :Count,
    :Textual,
)
    @eval begin
        function _density(::Type{AbstractArray{<:$S,D}}) where D
            return AbstractArray{Density{<:$S},D}
        end

        function _density(::Type{Table($S)})
            return Table(Density{$S})
        end
    end
end
80+
81+
# Finite scitypes carrying a type parameter `N`: bounded (`<:S{N}`) arrays,
# exact (`S{N}`) arrays and `Table(S{N})` each get a `_density` method.
for S in (:Finite, :Multiclass, :OrderedFactor)
    @eval begin
        function _density(::Type{AbstractArray{<:$S{N},D}}) where {N,D}
            return AbstractArray{Density{<:$S{N}},D}
        end

        function _density(::Type{AbstractArray{$S{N},D}}) where {N,D}
            return AbstractArray{Density{$S{N}},D}
        end

        function _density(::Type{Table($S{N})}) where N
            return Table(Density{$S{N}})
        end
    end
end

test/model_traits.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,67 @@ import .Fruit
7979
@test docstring(Float64) == "Float64"
8080
@test docstring(Fruit.Banana) == "Banana"
8181
end
82+
83+
@testset "`_density` - helper for predict_scitype fallback" begin
    # exact element scitype, arbitrary array dimension
    for S in (Continuous, Count, Textual)
        expected = AbstractArray{Density{S},3}
        @test M._density(AbstractArray{S,3}) == expected
    end

    # bounded element scitype and tables
    for S in (Finite, Multiclass, OrderedFactor, Infinite, Continuous, Count, Textual)
        expected_vector = AbstractVector{Density{<:S}}
        @test M._density(AbstractVector{<:S}) == expected_vector
        expected_table = Table(Density{S})
        @test M._density(Table(S)) == expected_table
    end

    # finite scitypes with an explicit type parameter
    for S in (Finite, Multiclass, OrderedFactor)
        S2param = S{2}
        @test M._density(AbstractArray{<:S2param,3}) ==
            AbstractArray{Density{<:S2param},3}
        @test M._density(AbstractArray{S2param,3}) ==
            AbstractArray{Density{S2param},3}
        @test M._density(Table(S2param)) == Table(Density{S2param})
    end
end
109+
110+
# Probabilistic fixture: table of `Continuous` in, vector of `Multiclass`
# target.
@mlj_model mutable struct P2 <: Probabilistic end
function M.target_scitype(::Type{<:P2})
    return AbstractVector{<:Multiclass}
end
function M.input_scitype(::Type{<:P2})
    return Table(Continuous)
end

# Unsupervised fixture: table of `Continuous` in, vector of `Multiclass` out.
@mlj_model mutable struct U2 <: Unsupervised end
function M.output_scitype(::Type{<:U2})
    return AbstractVector{<:Multiclass}
end
function M.input_scitype(::Type{<:U2})
    return Table(Continuous)
end

# Static fixture with the same declared input/output scitypes as `U2`.
@mlj_model mutable struct S2 <: Static end
function M.output_scitype(::Type{<:S2})
    return AbstractVector{<:Multiclass}
end
function M.input_scitype(::Type{<:S2})
    return Table(Continuous)
end
121+
122+
@testset "operation scitypes" begin
    probabilistic, unsupervised, static = P2(), U2(), S2()
    multiclass_vector = AbstractVector{<:Multiclass}
    continuous_table = Table(Continuous)

    # probabilistic model
    @test predict_scitype(probabilistic) == AbstractVector{Density{<:Multiclass}}
    @test transform_scitype(probabilistic) == Unknown

    # unsupervised model
    @test transform_scitype(unsupervised) == multiclass_vector
    @test inverse_transform_scitype(unsupervised) == continuous_table
    @test predict_scitype(unsupervised) == Unknown

    # static model
    @test transform_scitype(static) == multiclass_vector
    @test inverse_transform_scitype(static) == continuous_table
end
131+
132+
@testset "abstract_type, fit_data_scitype" begin
    # each instance reports the abstract model type it is expected to map to
    @test abstract_type(P2()) == Probabilistic
    @test abstract_type(S1()) == Supervised
    @test abstract_type(U1()) == Unsupervised
    @test abstract_type(D1()) == Deterministic
    @test abstract_type(P1()) == Probabilistic

    # tuple scitypes expected for the three fixtures
    supervised_data = Tuple{Table(Continuous),AbstractVector{<:Multiclass}}
    @test fit_data_scitype(P2()) == supervised_data
    unsupervised_data = Tuple{Table(Continuous)}
    @test fit_data_scitype(U2()) == unsupervised_data
    @test fit_data_scitype(S2()) == Tuple{}
end
144+
145+
true

0 commit comments

Comments
 (0)