@@ -5,11 +5,10 @@ similar fashion. The main differences are:
 
 - The `fit` method, which still returns `(fitresult, cache, report)` will typically have
   only one training argument `X`, as in `MLJModelInterface.fit(model, verbosity, X)`,
-  although this is not a hard requirement. For example, a feature selection tool (wrapping
-  some supervised model) might also include a target `y` as input. Furthermore, in the
-  case of models that subtype `Static <: Unsupervised` (see [Static
-  models](@ref)) `fit` has no training arguments at all, but does not need to be
-  implemented as a fallback returns `(nothing, nothing, nothing)`.
+  although this is not a hard requirement; see [Transformers requiring a target variable
+  in training](@ref) below. Furthermore, in the case of models that subtype `Static <:
+  Unsupervised` (see [Static models](@ref)), `fit` has no training arguments at all and
+  does not need to be implemented, as a fallback returns `(nothing, nothing, nothing)`.
 
 - A `transform` and/or `predict` method is implemented, and has the same signature as
   `predict` does in the supervised case, as in `MLJModelInterface.transform(model,
@@ -27,15 +26,43 @@ similar fashion. The main differences are:
   argument, you must overload the trait `fit_data_scitype`, which bounds the allowed
   `data` passed to `fit(model, verbosity, data...)` and will always be a `Tuple` type.
 
-- An `inverse_transform` can be optionally implemented. The signature
-  is the same as `transform`, as in
-  `MLJModelInterface.inverse_transform(model, fitresult, Xout)`, which:
+- An `inverse_transform` can be optionally implemented. The signature is the same as
+  `transform`, as in `MLJModelInterface.inverse_transform(model::MyUnsupervisedModel,
+  fitresult, Xout)`, which:
   - must make sense for any `Xout` for which `scitype(Xout) <:
-    output_scitype(SomeSupervisedModel)` (see below); and
+    output_scitype(MyUnsupervisedModel)`; and
   - must return an object `Xin` satisfying `scitype(Xin) <:
-    input_scitype(SomeSupervisedModel)`.
+    input_scitype(MyUnsupervisedModel)`.
 
-For sample implementatations, see MLJ's [built-in
+For sample implementations, see the sketch below, as well as MLJ's [built-in
 transformers](https://github.com/JuliaAI/MLJModels.jl/blob/dev/src/builtins/Transformers.jl)
 and the clustering models at
 [MLJClusteringInterface.jl](https://github.com/jbrea/MLJClusteringInterface.jl).
+
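+The following is a minimal sketch of the interface described above, for a hypothetical
+univariate transformer; the name `MyStandardizer` and all of its details are illustrative
+only and not part of MLJ:
+
+```julia
+import MLJModelInterface
+const MMI = MLJModelInterface
+using Statistics
+
+# Hypothetical unsupervised model standardizing a single vector of `Continuous` data:
+mutable struct MyStandardizer <: MMI.Unsupervised end
+
+function MMI.fit(::MyStandardizer, verbosity, v)
+    fitresult = (mean(v), std(v))  # the learned parameters
+    cache = nothing
+    report = NamedTuple()
+    return fitresult, cache, report
+end
+
+function MMI.transform(::MyStandardizer, fitresult, vnew)
+    μ, σ = fitresult
+    return (vnew .- μ) ./ σ
+end
+
+# Optional:
+function MMI.inverse_transform(::MyStandardizer, fitresult, w)
+    μ, σ = fitresult
+    return σ .* w .+ μ
+end
+
+MMI.input_scitype(::Type{<:MyStandardizer}) = AbstractVector{<:MMI.Continuous}
+MMI.output_scitype(::Type{<:MyStandardizer}) = AbstractVector{<:MMI.Continuous}
+```
+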
+## Transformers requiring a target variable in training
+
+An `Unsupervised` model that is not `Static` may include a second argument `y` in its
+`fit` signature, as in `fit(::MyTransformer, verbosity, X, y)`. For example, some feature
+selection tools require a target variable `y` in training. (Unlike `Supervised` models, an
+`Unsupervised` model is not required to implement `predict`, and in pipelines it is the
+output of `transform`, and not `predict`, that is always propagated to the next model.) Such a
+model should overload the trait `target_in_fit`, as in this example:
+
+```julia
+MLJModelInterface.target_in_fit(::Type{<:MyTransformer}) = true
+```
+
+This ensures that such models can appear in pipelines, and that a target provided to the
+pipeline model is passed on to the model in training.
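+
+For illustration only (the type `MySelector`, its scoring rule, and the assumption of
+`Continuous` features and target are all hypothetical, not part of MLJ), such a
+two-argument `fit` might look like this:
+
+```julia
+import MLJModelInterface
+const MMI = MLJModelInterface
+using Statistics, Tables
+
+# Hypothetical transformer keeping the `n` features most correlated with the target:
+mutable struct MySelector <: MMI.Unsupervised
+    n::Int
+end
+MySelector(; n=2) = MySelector(n)  # keyword constructor, as MLJ expects
+
+function MMI.fit(model::MySelector, verbosity, X, y)
+    cols = Tables.columntable(X)       # NamedTuple of column vectors
+    names = collect(keys(cols))
+    scores = [abs(cor(cols[name], y)) for name in names]
+    order = sortperm(scores, rev=true)
+    fitresult = names[order[1:min(model.n, length(names))]]  # selected feature names
+    report = (scores = scores,)
+    return fitresult, nothing, report
+end
+
+MMI.transform(::MySelector, fitresult, Xnew) = MMI.selectcols(Xnew, fitresult)
+
+MMI.target_in_fit(::Type{<:MySelector}) = true
+```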
+
+If the model implements more than one `fit` signature (e.g., one with a target `y` and one
+without), then `fit_data_scitype` must also be overloaded, as in this example:
+
+```julia
+MLJModelInterface.fit_data_scitype(::Type{<:MyTransformer}) = Union{
+    Tuple{Table(Continuous)},
+    Tuple{Table(Continuous), AbstractVector{<:Finite}},
+}
+```
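+
+The first tuple type above covers calls of the form `fit(model, verbosity, X)` and the
+second covers `fit(model, verbosity, X, y)`, so training data supplied in either form
+satisfies the declared bound; the particular scitypes shown are just an example.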
+