Skip to content

Commit 5ecb557

Browse files
committed
add Annotator and Detector types
bump 1.1.3 add UnsupervisedAnnotator <: Unsupervised add Detector <: Probabilistic overload target_scitype for Detector to be OrderedFactor{2} add augmented_transform stub export new types/methods list all model subtypes in a constant for easier MLJBase extension re-arrange oops typo oops oops oops tweaks add forgotten types to list
1 parent fc7191c commit 5ecb557

File tree

4 files changed

+136
-32
lines changed

4 files changed

+136
-32
lines changed

src/MLJModelInterface.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module MLJModelInterface
22

3-
const MODEL_TRAITS = [
4-
:input_scitype,
3+
const MODEL_TRAITS =
4+
[:input_scitype,
55
:output_scitype,
66
:target_scitype,
77
:fit_data_scitype,
@@ -38,7 +38,17 @@ const ABSTRACT_MODEL_SUBTYPES =
3838
:Deterministic,
3939
:Interval,
4040
:JointProbabilistic,
41-
:Static]
41+
:Static,
42+
:Annotator,
43+
:SupervisedAnnotator,
44+
:UnsupervisedAnnotator,
45+
:SupervisedDetector,
46+
:UnsupervisedDetector,
47+
:AbstractProbabilisticSupervisedDetector,
48+
:AbstractProbabilisticUnsupervisedDetector,
49+
:AbstractDeterministicSupervisedDetector,
50+
:AbstractDeterministicUnsupervisedDetector]
51+
4252

4353
# ------------------------------------------------------------------------
4454
# Dependencies
@@ -69,7 +79,8 @@ export @mlj_model, metadata_pkg, metadata_model
6979
# model api
7080
export fit, update, update_data, transform, inverse_transform,
7181
fitted_params, predict, predict_mode, predict_mean, predict_median,
72-
predict_joint, evaluate, clean!, reformat, training_losses
82+
predict_joint, evaluate, clean!, reformat, training_losses,
83+
augmented_predict
7384

7485
# model traits
7586
for trait in MODEL_TRAITS
@@ -118,17 +129,30 @@ abstract type Model <: MLJType end
118129
# ------------------------------------------------------------------------
119130
# Model subtypes
120131

121-
abstract type Supervised <: Model end
122-
abstract type Unsupervised <: Model end
132+
abstract type Supervised <: Model end
133+
abstract type Unsupervised <: Model end
134+
abstract type Annotator <: Model end
123135

124136
abstract type Probabilistic <: Supervised end
125137
abstract type Deterministic <: Supervised end
126138
abstract type Interval <: Supervised end
127139

128-
abstract type Static <: Unsupervised end
129-
130140
abstract type JointProbabilistic <: Probabilistic end
131141

142+
abstract type Static <: Unsupervised end
143+
144+
abstract type SupervisedAnnotator <: Annotator end
145+
abstract type UnsupervisedAnnotator <: Annotator end
146+
147+
abstract type UnsupervisedDetector <: UnsupervisedAnnotator end
148+
abstract type SupervisedDetector <: SupervisedAnnotator end
149+
150+
abstract type AbstractProbabilisticSupervisedDetector <: SupervisedDetector end
151+
abstract type AbstractProbabilisticUnsupervisedDetector <: UnsupervisedDetector end
152+
153+
abstract type AbstractDeterministicSupervisedDetector <: SupervisedDetector end
154+
abstract type AbstractDeterministicUnsupervisedDetector <: UnsupervisedDetector end
155+
132156
# ------------------------------------------------------------------------
133157
# includes
134158

src/data_utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,3 +434,9 @@ UnivariateFinite(probs; kwargs...) =
434434
UnivariateFinite(get_interface_mode(), probs; kwargs...)
435435
UnivariateFinite(::LightInterface, a...; kwargs...) =
436436
errlight("UnivariateFinite")
437+
438+
## FOR DETECTION MODELS
439+
440+
const OUTLIER = "outlier"
441+
const INLIER = "inlier"
442+
const UNKNOWN = "unknown"

src/model_api.jl

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)
1515
# fallbacks for supervised models that don't support sample weights:
1616
fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)
1717

18+
# fallback for unsupervised detectors when no "evaluation" labels appear:
19+
# `y` holds optional "evaluation" labels, which an unsupervised detector ignores:
fit(m::UnsupervisedDetector, verbosity, X, y) = fit(m, verbosity, X)
20+
1821
"""
1922
MLJModelInterface.update(model, verbosity, fitresult, cache, data...)
2023
@@ -98,24 +101,54 @@ fitted_params(::Model, fitresult) = (fitresult=fitresult,)
98101
99102
predict(model, fitresult, new_data...)
100103
101-
`Supervised` models must implement the `predict` operation. Here
102-
`new_data` is the output of `reformat` called on user-specified data.
104+
`Supervised` and `SupervisedAnnotator` models must implement the
105+
`predict` operation. Here `new_data` is the output of `reformat`
106+
called on user-specified data.
103107
104108
"""
105109
function predict end
106110

107111
"""
108-
probabilistic supervised models may overload `predict_mean`
112+
augmented_predict
113+
114+
If implemented, the same as `predict`, but with a return value
115+
augmented by the `predict`ion of the training data.
116+
117+
For example, if implemented for a `Supervised` model with a
118+
`predict` method, `augmented_predict(model, fitresult, Xnew)` will
119+
return
120+
121+
```julia
122+
(predict(model, fitresult, X), predict(model, fitresult, Xnew))
123+
```
124+
125+
where `(X, y)` was the training data.
126+
127+
Must be implemented by any `UnsupervisedDetector` or `SupervisedDetector`.
128+
129+
"""
130+
function augmented_predict end
131+
109132
"""
110-
function predict_mean end
111133
134+
Model types `M` for which `prediction_type(M) == :probabilistic` may
135+
overload `predict_mean`.
136+
137+
"""
138+
function predict_mean end
112139
"""
113-
probabilistic supervised models may overload `predict_mode`
140+
141+
Model types `M` for which `prediction_type(M) == :probabilistic` may
142+
overload `predict_mode`.
143+
114144
"""
115145
function predict_mode end
116146

117147
"""
118-
probabilistic supervised models may overload `predict_median`
148+
149+
Model types `M` for which `prediction_type(M) == :probabilistic` may
150+
overload `predict_median`.
151+
119152
"""
120153
function predict_median end
121154

@@ -127,12 +160,14 @@ function predict_median end
127160
function predict_joint end
128161

129162
"""
130-
unsupervised methods must implement the `transform` operation
163+
`Unsupervised` models must implement the `transform` operation.
131164
"""
132165
function transform end
133166

134167
"""
135-
unsupervised methods may implement the `inverse_transform` operation
168+
169+
`Unsupervised` models may implement the `inverse_transform` operation.
170+
136171
"""
137172
function inverse_transform end
138173

@@ -145,3 +180,4 @@ function restore end
145180
some meta-models may choose to implement the `evaluate` operations
146181
"""
147182
function evaluate end
183+

src/model_traits.jl

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,33 @@
11
## OVERLOADING TRAIT DEFAULTS RELEVANT TO MODELS
22

3-
StatisticalTraits.docstring(M::Type{<:MLJType}) = name(M)
4-
StatisticalTraits.docstring(M::Type{<:Model}) =
3+
# unexported aliases:
4+
const Detector = Union{SupervisedDetector,UnsupervisedDetector}
5+
const ProbabilisticDetector = Union{AbstractProbabilisticSupervisedDetector,
6+
AbstractProbabilisticUnsupervisedDetector}
7+
const DeterministicDetector = Union{AbstractDeterministicSupervisedDetector,
8+
AbstractDeterministicUnsupervisedDetector}
9+
10+
const StatTraits = StatisticalTraits
11+
12+
StatTraits.docstring(M::Type{<:MLJType}) = name(M)
13+
StatTraits.docstring(M::Type{<:Model}) =
514
"$(name(M)) from $(package_name(M)).jl.\n" *
615
"[Documentation]($(package_url(M)))."
716

8-
StatisticalTraits.is_supervised(::Type{<:Supervised}) = true
9-
StatisticalTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
10-
StatisticalTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
11-
StatisticalTraits.prediction_type(::Type{<:Interval}) = :interval
17+
StatTraits.is_supervised(::Type{<:Supervised}) = true
18+
19+
StatTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
20+
StatTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
21+
StatTraits.prediction_type(::Type{<:Interval}) = :interval
22+
StatTraits.prediction_type(::Type{<:ProbabilisticDetector}) =
23+
:probabilistic
24+
StatTraits.prediction_type(::Type{<:DeterministicDetector}) =
25+
:deterministic
26+
27+
StatTraits.target_scitype(::Type{<:ProbabilisticDetector}) =
28+
AbstractVector{OrderedFactor{2}}
29+
StatTraits.target_scitype(::Type{<:DeterministicDetector}) =
30+
AbstractVector{OrderedFactor{2}}
1231

1332
# implementation is deferred as it requires methodswith which depends upon
1433
# InteractiveUtils which we don't want to bring here as a dependency
@@ -18,13 +37,13 @@ implemented_methods(model) = implemented_methods(typeof(model))
1837
implemented_methods(::LightInterface, M) = errlight("implemented_methods")
1938

2039
for M in ABSTRACT_MODEL_SUBTYPES
21-
@eval(StatisticalTraits.abstract_type(::Type{<:$M}) = $M)
40+
@eval(StatTraits.abstract_type(::Type{<:$M}) = $M)
2241
end
2342

24-
StatisticalTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
43+
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
2544
Tuple{input_scitype(M)}
26-
StatisticalTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
27-
function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised})
45+
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
46+
function StatTraits.fit_data_scitype(M::Type{<:Supervised})
2847
I = input_scitype(M)
2948
T = target_scitype(M)
3049
ret = Tuple{I,T}
@@ -37,24 +56,42 @@ function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised})
3756
end
3857
return ret
3958
end
59+
StatTraits.fit_data_scitype(M::Type{<:UnsupervisedAnnotator}) =
60+
Tuple{input_scitype(M)}
61+
StatTraits.fit_data_scitype(M::Type{<:SupervisedAnnotator}) =
62+
Tuple{input_scitype(M),target_scitype(M)}
63+
64+
# In the special case of `AbstractProbabilisticUnsupervisedDetector` and
65+
# `AbstractDeterministicUnsupervisedDetector`, we allow the target as an
66+
# optional argument to `fit` (that is ignored) so that the `machine`
67+
# constructor will accept it as a valid argument, which then enables
68+
# *evaluation* of the detector with labeled data:
69+
StatTraits.fit_data_scitype(M::Type{<:Union{
70+
AbstractProbabilisticUnsupervisedDetector,
71+
AbstractDeterministicUnsupervisedDetector}}) =
72+
Union{Tuple{input_scitype(M)},
73+
Tuple{input_scitype(M),target_scitype(M)}}
4074

41-
StatisticalTraits.transform_scitype(M::Type{<:Unsupervised}) =
75+
StatTraits.transform_scitype(M::Type{<:Unsupervised}) =
4276
output_scitype(M)
4377

44-
StatisticalTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
78+
StatTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
4579
input_scitype(M)
4680

47-
StatisticalTraits.predict_scitype(M::Type{<:Deterministic}) = target_scitype(M)
81+
StatTraits.predict_scitype(M::Type{<:Union{
82+
Deterministic,DeterministicDetector}}) = target_scitype(M)
4883

4984

50-
## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` MODELS
85+
## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` and
86+
## `ProbabilisticDetector` MODELS
5187

5288
# This seems less than ideal but should reduce the number of `Unknown`
5389
# in `predict_scitype` for models which, historically, have not
5490
# implemented the trait.
5591

56-
StatisticalTraits.predict_scitype(M::Type{<:Probabilistic}) =
57-
_density(target_scitype(M))
92+
StatTraits.predict_scitype(
93+
M::Type{<:Union{Probabilistic,ProbabilisticDetector}}
94+
) = _density(target_scitype(M))
5895

5996
_density(::Any) = Unknown
6097
for T in [:Continuous, :Count, :Textual]
@@ -78,6 +115,7 @@ for T in [:Finite,
78115
end)
79116
end
80117

118+
81119
for T in [:Finite, :Multiclass, :OrderedFactor]
82120
eval(quote
83121
_density(::Type{AbstractArray{<:$T{N},D}}) where {N,D} =

0 commit comments

Comments
 (0)