Commit aa7f923

document "supervised" transformers to close #203
1 parent e3d2571 commit aa7f923

1 file changed: +38 -11 lines changed

docs/src/unsupervised_models.md

Lines changed: 38 additions & 11 deletions
@@ -5,11 +5,10 @@ similar fashion. The main differences are:
 
 - The `fit` method, which still returns `(fitresult, cache, report)` will typically have
 only one training argument `X`, as in `MLJModelInterface.fit(model, verbosity, X)`,
-although this is not a hard requirement. For example, a feature selection tool (wrapping
-some supervised model) might also include a target `y` as input. Furthermore, in the
-case of models that subtype `Static <: Unsupervised` (see [Static
-models](@ref)) `fit` has no training arguments at all, but does not need to be
-implemented as a fallback returns `(nothing, nothing, nothing)`.
+although this is not a hard requirement; see [Transformers requiring a target variable
+in training](@ref) below. Furthermore, in the case of models that subtype `Static <:
+Unsupervised` (see [Static models](@ref)) `fit` has no training arguments at all, but
+does not need to be implemented as a fallback returns `(nothing, nothing, nothing)`.
 
 - A `transform` and/or `predict` method is implemented, and has the same signature as
 `predict` does in the supervised case, as in `MLJModelInterface.transform(model,
@@ -27,15 +26,43 @@ similar fashion. The main differences are:
 argument, you must overload the trait `fit_data_scitype`, which bounds the allowed
 `data` passed to `fit(model, verbosity, data...)` and will always be a `Tuple` type.
 
-- An `inverse_transform` can be optionally implemented. The signature
-is the same as `transform`, as in
-`MLJModelInterface.inverse_transform(model, fitresult, Xout)`, which:
+- An `inverse_transform` can be optionally implemented. The signature is the same as
+`transform`, as in `MLJModelInterface.inverse_transform(model::MyUnsupervisedModel,
+fitresult, Xout)`, which:
 - must make sense for any `Xout` for which `scitype(Xout) <:
-output_scitype(SomeSupervisedModel)` (see below); and
+output_scitype(MyUnsupervisedModel)`; and
 - must return an object `Xin` satisfying `scitype(Xin) <:
-input_scitype(SomeSupervisedModel)`.
+input_scitype(MyUnsupervisedModel)`.
 
-For sample implementatations, see MLJ's [built-in
+For sample implementations, see MLJ's [built-in
 transformers](https://github.com/JuliaAI/MLJModels.jl/blob/dev/src/builtins/Transformers.jl)
 and the clustering models at
 [MLJClusteringInterface.jl](https://github.com/jbrea/MLJClusteringInterface.jl).
+
+## Transformers requiring a target variable in training
+
+An `Unsupervised` model that is not `Static` may include a second argument `y` in its
+`fit` signature, as in `fit(::MyTransformer, verbosity, X, y)`. For example, some feature
+selection tools require a target variable `y` in training. (Unlike `Supervised` models, an
+`Unsupervised` model is not required to implement `predict`, and in pipelines it is the
+output of `transform`, and not `predict`, that is always propagated to the next model.) Such a
+model should overload the trait `target_in_fit`, as in this example:
+
+```julia
+MLJModelInterface.target_in_fit(::Type{<:MyTransformer}) = true
+```
+
+This ensures that such models can appear in pipelines, and that a target provided to the
+pipeline model is passed on to the model in training.
+
+If the model implements more than one `fit` signature (e.g., one with a target `y` and one
+without) then `fit_data_scitype` must also be overloaded, as in this example:
+
+```julia
+MLJModelInterface.fit_data_scitype(::Type{<:MyTransformer}) = Union{
+    Tuple{Table(Continuous)},
+    Tuple{Table(Continuous), AbstractVector{<:Finite}},
+}
+```
+
+
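
The new section on transformers that take a target in training can be illustrated with a minimal sketch. Everything named here is hypothetical and not part of the commit: `CorrelationSelector`, its `threshold` field, and the choice to treat `X` as a named tuple of vectors (one simple kind of column table). It assumes an MLJModelInterface version recent enough to provide `target_in_fit`.

```julia
import MLJModelInterface
using Statistics: cor
const MMI = MLJModelInterface

# Hypothetical transformer: keeps the features whose absolute correlation
# with the target is at least `threshold`.
mutable struct CorrelationSelector <: MMI.Unsupervised
    threshold::Float64
end
CorrelationSelector(; threshold=0.5) = CorrelationSelector(threshold)

# `fit` takes a second training argument `y`. Assumption: `X` is a named
# tuple of equal-length numeric vectors.
function MMI.fit(model::CorrelationSelector, verbosity, X, y)
    kept = Tuple(name for name in keys(X) if abs(cor(X[name], y)) >= model.threshold)
    fitresult = kept
    cache = nothing
    report = (features_kept = kept,)
    return fitresult, cache, report
end

# `transform` never sees the target; it only selects the retained columns.
MMI.transform(::CorrelationSelector, fitresult, Xnew) = NamedTuple{fitresult}(Xnew)

# Declare that `fit` consumes a target, as described in the section above.
MMI.target_in_fit(::Type{<:CorrelationSelector}) = true

# Usage
X = (x1 = [1.0, 2.0, 3.0, 4.0], x2 = [1.0, -1.0, -1.0, 1.0])
y = [2.0, 4.1, 5.9, 8.0]
model = CorrelationSelector(threshold=0.9)
fitresult, _, report = MMI.fit(model, 0, X, y)
Xsmall = MMI.transform(model, fitresult, X)
keys(Xsmall)  # (:x1,)
```

Because the sketch implements only one `fit` signature, it does not overload `fit_data_scitype`; a model that also accepted `fit(model, verbosity, X)` would need that extra declaration.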

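The `inverse_transform` contract touched by this commit (accept anything matching `output_scitype`, return something matching `input_scitype`) might look like the following in practice. This is only a sketch under stated assumptions: `VectorStandardizer` is a hypothetical model acting on plain vectors of reals, and the scitype declarations are one plausible choice rather than anything prescribed by the documentation.

```julia
import MLJModelInterface
using Statistics: mean, std
const MMI = MLJModelInterface

# Hypothetical transformer that standardizes a vector of reals.
mutable struct VectorStandardizer <: MMI.Unsupervised end

# A single training argument, as is typical for unsupervised models.
function MMI.fit(::VectorStandardizer, verbosity, x)
    fitresult = (mu = mean(x), sigma = std(x))
    return fitresult, nothing, NamedTuple()
end

MMI.transform(::VectorStandardizer, fitresult, xnew) =
    (xnew .- fitresult.mu) ./ fitresult.sigma

# Same signature as `transform`; maps outputs back to the input space.
MMI.inverse_transform(::VectorStandardizer, fitresult, xout) =
    xout .* fitresult.sigma .+ fitresult.mu

# Assumed scitype declarations: real vectors in, real vectors out.
MMI.input_scitype(::Type{<:VectorStandardizer}) = AbstractVector{MMI.Continuous}
MMI.output_scitype(::Type{<:VectorStandardizer}) = AbstractVector{MMI.Continuous}

# Usage: a round trip should recover the original data up to rounding.
x = [1.0, 2.0, 3.0, 6.0]
standardizer = VectorStandardizer()
params, _, _ = MMI.fit(standardizer, 0, x)
w = MMI.transform(standardizer, params, x)
xback = MMI.inverse_transform(standardizer, params, w)
```

The sketch ignores the degenerate case of a constant input vector, where `sigma` would be zero.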