Skip to content

Commit 05091c5

Browse files
authored
Merge pull request #114 from JuliaAI/dev
For a 1.3 release
2 parents 6288bac + d9a7c69 commit 05091c5

File tree

4 files changed

+113
-33
lines changed

4 files changed

+113
-33
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "1.2.0"
4+
version = "1.3.0"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/MLJModelInterface.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module MLJModelInterface
22

3-
const MODEL_TRAITS = [
4-
:input_scitype,
3+
const MODEL_TRAITS =
4+
[:input_scitype,
55
:output_scitype,
66
:target_scitype,
77
:fit_data_scitype,
@@ -38,7 +38,17 @@ const ABSTRACT_MODEL_SUBTYPES =
3838
:Deterministic,
3939
:Interval,
4040
:JointProbabilistic,
41-
:Static]
41+
:Static,
42+
:Annotator,
43+
:SupervisedAnnotator,
44+
:UnsupervisedAnnotator,
45+
:SupervisedDetector,
46+
:UnsupervisedDetector,
47+
:ProbabilisticSupervisedDetector,
48+
:ProbabilisticUnsupervisedDetector,
49+
:DeterministicSupervisedDetector,
50+
:DeterministicUnsupervisedDetector]
51+
4252

4353
# ------------------------------------------------------------------------
4454
# Dependencies
@@ -69,7 +79,8 @@ export @mlj_model, metadata_pkg, metadata_model
6979
# model api
7080
export fit, update, update_data, transform, inverse_transform,
7181
fitted_params, predict, predict_mode, predict_mean, predict_median,
72-
predict_joint, evaluate, clean!, reformat, training_losses
82+
predict_joint, evaluate, clean!, reformat, training_losses,
83+
augmented_transform
7384

7485
# model traits
7586
for trait in MODEL_TRAITS
@@ -118,17 +129,30 @@ abstract type Model <: MLJType end
118129
# ------------------------------------------------------------------------
119130
# Model subtypes
120131

121-
abstract type Supervised <: Model end
122-
abstract type Unsupervised <: Model end
132+
abstract type Supervised <: Model end
133+
abstract type Unsupervised <: Model end
134+
abstract type Annotator <: Model end
123135

124136
abstract type Probabilistic <: Supervised end
125137
abstract type Deterministic <: Supervised end
126138
abstract type Interval <: Supervised end
127139

128-
abstract type Static <: Unsupervised end
129-
130140
abstract type JointProbabilistic <: Probabilistic end
131141

142+
abstract type Static <: Unsupervised end
143+
144+
abstract type SupervisedAnnotator <: Annotator end
145+
abstract type UnsupervisedAnnotator <: Annotator end
146+
147+
abstract type UnsupervisedDetector <: UnsupervisedAnnotator end
148+
abstract type SupervisedDetector <: SupervisedAnnotator end
149+
150+
abstract type ProbabilisticSupervisedDetector <: SupervisedDetector end
151+
abstract type ProbabilisticUnsupervisedDetector <: UnsupervisedDetector end
152+
153+
abstract type DeterministicSupervisedDetector <: SupervisedDetector end
154+
abstract type DeterministicUnsupervisedDetector <: UnsupervisedDetector end
155+
132156
# ------------------------------------------------------------------------
133157
# includes
134158

src/model_api.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)
1515
# fallbacks for supervised models that don't support sample weights:
1616
fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)
1717

18+
# fallback for unsupervised detectors when "evaluation" labels `y` are supplied (and ignored):
19+
fit(m::Union{ProbabilisticUnsupervisedDetector,
20+
DeterministicUnsupervisedDetector},
21+
verbosity,
22+
X,
23+
y) = fit(m, verbosity, X)
24+
1825
"""
1926
MLJModelInterface.update(model, verbosity, fitresult, cache, data...)
2027
@@ -98,24 +105,33 @@ fitted_params(::Model, fitresult) = (fitresult=fitresult,)
98105
99106
predict(model, fitresult, new_data...)
100107
101-
`Supervised` models must implement the `predict` operation. Here
102-
`new_data` is the output of `reformat` called on user-specified data.
108+
`Supervised` and `SupervisedAnnotator` models must implement the
109+
`predict` operation. Here `new_data` is the output of `reformat`
110+
called on user-specified data.
103111
104112
"""
105113
function predict end
106114

107115
"""
108-
probabilistic supervised models may overload `predict_mean`
116+
117+
Model types `M` for which `prediction_type(M) == :probabilistic` may
118+
overload `predict_mean`.
119+
109120
"""
110121
function predict_mean end
111-
112122
"""
113-
probabilistic supervised models may overload `predict_mode`
123+
124+
Model types `M` for which `prediction_type(M) == :probabilistic` may
125+
overload `predict_mode`.
126+
114127
"""
115128
function predict_mode end
116129

117130
"""
118-
probabilistic supervised models may overload `predict_median`
131+
132+
Model types `M` for which `prediction_type(M) == :probabilistic` may
133+
overload `predict_median`.
134+
119135
"""
120136
function predict_median end
121137

@@ -127,12 +143,14 @@ function predict_median end
127143
function predict_joint end
128144

129145
"""
130-
unsupervised methods must implement the `transform` operation
146+
`Unsupervised` models must implement the `transform` operation.
131147
"""
132148
function transform end
133149

134150
"""
135-
unsupervised methods may implement the `inverse_transform` operation
151+
152+
`Unsupervised` models may implement the `inverse_transform` operation.
153+
136154
"""
137155
function inverse_transform end
138156

src/model_traits.jl

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,33 @@
11
## OVERLOADING TRAIT DEFAULTS RELEVANT TO MODELS
22

3-
StatisticalTraits.docstring(M::Type{<:MLJType}) = name(M)
4-
StatisticalTraits.docstring(M::Type{<:Model}) =
3+
# unexported aliases:
4+
const Detector = Union{SupervisedDetector,UnsupervisedDetector}
5+
const ProbabilisticDetector = Union{ProbabilisticSupervisedDetector,
6+
ProbabilisticUnsupervisedDetector}
7+
const DeterministicDetector = Union{DeterministicSupervisedDetector,
8+
DeterministicUnsupervisedDetector}
9+
10+
const StatTraits = StatisticalTraits
11+
12+
StatTraits.docstring(M::Type{<:MLJType}) = name(M)
13+
StatTraits.docstring(M::Type{<:Model}) =
514
"$(name(M)) from $(package_name(M)).jl.\n" *
615
"[Documentation]($(package_url(M)))."
716

8-
StatisticalTraits.is_supervised(::Type{<:Supervised}) = true
9-
StatisticalTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
10-
StatisticalTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
11-
StatisticalTraits.prediction_type(::Type{<:Interval}) = :interval
17+
StatTraits.is_supervised(::Type{<:Supervised}) = true
18+
19+
StatTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
20+
StatTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
21+
StatTraits.prediction_type(::Type{<:Interval}) = :interval
22+
StatTraits.prediction_type(::Type{<:ProbabilisticDetector}) =
23+
:probabilistic
24+
StatTraits.prediction_type(::Type{<:DeterministicDetector}) =
25+
:deterministic
26+
27+
StatTraits.target_scitype(::Type{<:ProbabilisticDetector}) =
28+
AbstractVector{<:Union{Missing,OrderedFactor{2}}}
29+
StatTraits.target_scitype(::Type{<:DeterministicDetector}) =
30+
AbstractVector{<:Union{Missing,OrderedFactor{2}}}
1231

1332
# implementation is deferred as it requires methodswith which depends upon
1433
# InteractiveUtils which we don't want to bring here as a dependency
@@ -18,13 +37,13 @@ implemented_methods(model) = implemented_methods(typeof(model))
1837
implemented_methods(::LightInterface, M) = errlight("implemented_methods")
1938

2039
for M in ABSTRACT_MODEL_SUBTYPES
21-
@eval(StatisticalTraits.abstract_type(::Type{<:$M}) = $M)
40+
@eval(StatTraits.abstract_type(::Type{<:$M}) = $M)
2241
end
2342

24-
StatisticalTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
43+
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
2544
Tuple{input_scitype(M)}
26-
StatisticalTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
27-
function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised})
45+
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
46+
function StatTraits.fit_data_scitype(M::Type{<:Supervised})
2847
I = input_scitype(M)
2948
T = target_scitype(M)
3049
ret = Tuple{I,T}
@@ -37,24 +56,42 @@ function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised})
3756
end
3857
return ret
3958
end
59+
StatTraits.fit_data_scitype(M::Type{<:UnsupervisedAnnotator}) =
60+
Tuple{input_scitype(M)}
61+
StatTraits.fit_data_scitype(M::Type{<:SupervisedAnnotator}) =
62+
Tuple{input_scitype(M),target_scitype(M)}
63+
64+
# In the special case of `ProbabilisticUnsupervisedDetector` and
65+
# `DeterministicUnsupervisedDetector` we allow the target as an
66+
# optional argument to `fit` (that is ignored) so that the `machine`
67+
# constructor will accept it as a valid argument, which then enables
68+
# *evaluation* of the detector with labeled data:
69+
StatTraits.fit_data_scitype(M::Type{<:Union{
70+
ProbabilisticUnsupervisedDetector,
71+
DeterministicUnsupervisedDetector}}) =
72+
Union{Tuple{input_scitype(M)},
73+
Tuple{input_scitype(M),target_scitype(M)}}
4074

41-
StatisticalTraits.transform_scitype(M::Type{<:Unsupervised}) =
75+
StatTraits.transform_scitype(M::Type{<:Unsupervised}) =
4276
output_scitype(M)
4377

44-
StatisticalTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
78+
StatTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
4579
input_scitype(M)
4680

47-
StatisticalTraits.predict_scitype(M::Type{<:Deterministic}) = target_scitype(M)
81+
StatTraits.predict_scitype(M::Type{<:Union{
82+
Deterministic,DeterministicDetector}}) = target_scitype(M)
4883

4984

50-
## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` MODELS
85+
## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` and
86+
## `ProbabilisticDetector` MODELS
5187

5288
# This seems less than ideal but should reduce the number of `Unknown`
5389
# in `predict_scitype` for models which, historically, have not
5490
# implemented the trait.
5591

56-
StatisticalTraits.predict_scitype(M::Type{<:Probabilistic}) =
57-
_density(target_scitype(M))
92+
StatTraits.predict_scitype(
93+
M::Type{<:Union{Probabilistic,ProbabilisticDetector}}
94+
) = _density(target_scitype(M))
5895

5996
_density(::Any) = Unknown
6097
for T in [:Continuous, :Count, :Textual]
@@ -78,6 +115,7 @@ for T in [:Finite,
78115
end)
79116
end
80117

118+
81119
for T in [:Finite, :Multiclass, :OrderedFactor]
82120
eval(quote
83121
_density(::Type{AbstractArray{<:$T{N},D}}) where {N,D} =

0 commit comments

Comments
 (0)