@@ -33,41 +33,102 @@ function update_data end
33
33
MLJModelInterface.reformat(model, args...) -> data
34
34
35
35
Models optionally overload `reformat` to define transformations of
36
- user-supplied training data `args` into some model-specific
37
- representation `data` (e.g., from a table to a matrix). The fallback
38
- returns `args` (no transformation).
36
+ user-supplied data into some model-specific representation (e.g., from
37
+ a table to a matrix). Computational overheads associated with multiple
38
+ `fit!`/`predict`/`transform` calls are then avoided, when memory
39
+ resources allow. The fallback returns `args` (no transformation).
39
40
40
- If `mach` is a machine with `mach.model == model` then calling `fit!(mach)`
41
- will either call
41
+ Here "user-supplied data" is what the MLJ user supplies when
42
+ constructing a machine, as in `machine(model, args...)`, which
43
+ coincides with the arguments expected by `fit(model, verbosity,
44
+ args...)` when `reformat` is not overloaded.
42
45
43
- fit(model, verbosity, data...)
46
+ Implementing a `reformat` data front-end is permitted for any `Model`
47
+ subtype, except for subtypes of `Static`. Here is a complete list of
48
+ responsibilities for such an implementation, for some
49
+ `model::SomeModelType`:
44
50
45
- or
51
+ - A `reformat(model::SomeModelType, args...) -> data` method must be
52
+ implemented for each form of `args...` appearing in a valid machine
53
+ construction `machine(model, args...)` (there will be one for each
54
+ possible signature of `fit(::SomeModelType, ...)`).
46
55
47
- update(model, verbosity, data...)
56
+ - Additionally, if not included above, there must be a single argument
57
+ form of `reformat`, `reformat(model::SomeModelType, arg) -> (data,)`,
58
+ serving as a data front-end for operations like `predict`. It must
59
+ always hold that `reformat(model, args...)[1] == reformat(model,
60
+ args[1])[1]`.
48
61
49
- where `data = reformat(model, mach.args...)`. This means that
50
- overloading `reformat` alters the form of the arguments expected by
51
- `fit` and `update`.
62
+ **Warning.** `reformat(model::SomeModelType, args...)` must always
63
+ return a tuple of the same length as `args`, even if this is one.
52
64
53
- If, instead, one calls `fit!(mach, rows=I)`, then `data` in the above
54
- `fit`/`update` calls is replaced with `selectrows(model, I,
55
- data...)`. So overloading `reformat` generally requires overloading of
56
- `selectrows` also, to specify how the model-specific representation of
57
- training data is resampled.
65
+ - `fit(model::SomeModelType, verbosity, data...)` should be
66
+ implemented as if `data` is the output of `reformat(model,
67
+ args...)`, where `args` is the data an MLJ user has bound to `model`
68
+ in some machine. The same applies to any overloading of `update`.
58
69
59
- Note `data` is always a tuple, but it needn't have the same length as
60
- `args`.
70
+ - Each implemented operation, such as `predict` and `transform` - but
71
+ excluding `inverse_transform` - must be defined as if its data
72
+ arguments are `reformat`ed versions of user-supplied data. For
73
+ example, in the supervised case, `data_new` in
74
+ `predict(model::SomeModelType, fitresult, data_new)` is
75
+ `reformat(model, Xnew)`, where `Xnew` is the data provided by the MLJ
76
+ user in a call `predict(mach, Xnew)` (`mach.model == model`).
77
+
78
+ - To specify how the model-specific representation of data is to be
79
+ resampled, implement `selectrows(model::SomeModelType, I, data...)
80
+ -> resampled_data` for each overloading of `reformat(model::SomeModelType,
81
+ args...) -> data` above. Here `I` is an arbitrary abstract integer
82
+ vector or `:` (type `Colon`).
83
+
84
+ **Warning.** `selectrows(model::SomeModelType, I, args...)` must always
85
+ return a tuple of the same length as `args`, even if this is one.
86
+
87
+ The fallback for `selectrows` is described at [`selectrows`](@ref).
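
The responsibilities above can be illustrated from the caller's side. The following is a minimal sketch in plain Julia, not the MLJ implementation: `ToyModel`, `ToyRegressor`, `toy_reformat`, `toy_selectrows`, `toy_getrows`, `toy_fit` and `toy_train` are all hypothetical stand-ins introduced here, and the fallbacks only imitate the documented defaults (no transformation, element-wise row selection).

```julia
# Toy stand-ins only; none of these names belong to MLJModelInterface.
abstract type ToyModel end
struct ToyRegressor <: ToyModel end

# Fallbacks imitating the documented defaults: `reformat` returns its
# arguments unchanged, and `selectrows` selects rows from each element,
# returning a tuple of the same length.
toy_reformat(::ToyModel, args...) = args
toy_selectrows(::ToyModel, I, data...) = map(d -> toy_getrows(d, I), data)
toy_getrows(v::AbstractVector, I) = view(v, I)
toy_getrows(A::AbstractMatrix, I) = view(A, I, :)

# A trivial "fit" that only records how many observations it received.
toy_fit(::ToyRegressor, verbosity, X, y) = (nobs = length(y),)

# Hypothetical caller: reformat the user-supplied data, resample the
# model-specific representation, then pass the result to `fit`.
function toy_train(model, args...; rows = Colon())
    data = toy_reformat(model, args...)
    sampled = toy_selectrows(model, rows, data...)
    return toy_fit(model, 0, sampled...)
end

X = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # observations as rows
y = [10.0, 20.0, 30.0]
fitresult = toy_train(ToyRegressor(), X, y; rows = [1, 3])
```

In this sketch `fit` only ever sees the output of `toy_selectrows`, which is why both functions must return tuples that can be splatted into the `fit` signature.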
88
+
89
+
90
+ ### Example
91
+
92
+ Suppose a supervised model type `SomeSupervised` supports sample
93
+ weights, leading to two different `fit` signatures:
94
+
95
+ fit(model::SomeSupervised, verbosity, X, y)
96
+ fit(model::SomeSupervised, verbosity, X, y, w)
97
+
98
+ predict(model::SomeSupervised, fitresult, Xnew)
99
+
100
+ Without a data front-end implemented, suppose `X` is expected to be a
101
+ table and `y` a vector, but suppose the core algorithm always converts
102
+ `X` to a matrix with features as rows (features corresponding to
103
+ columns in the table). Then a new data front-end might look like
104
+ this:
105
+
106
+ const MMI = MLJModelInterface
107
+
108
+ # for fit:
109
+ MMI.reformat(::SomeSupervised, X, y) = (MMI.matrix(X, transpose=true), y)
110
+ MMI.reformat(::SomeSupervised, X, y, w) = (MMI.matrix(X, transpose=true), y, w)
111
+ MMI.selectrows(::SomeSupervised, I, Xmatrix, y) =
112
+ (view(Xmatrix, :, I), view(y, I))
113
+ MMI.selectrows(::SomeSupervised, I, Xmatrix, y, w) =
114
+ (view(Xmatrix, :, I), view(y, I), view(w, I))
115
+
116
+ # for predict:
117
+ MMI.reformat(::SomeSupervised, X) = (MMI.matrix(X, transpose=true),)
118
+ MMI.selectrows(::SomeSupervised, I, Xmatrix) = (view(Xmatrix, :, I),)
119
+
120
+ With these additions, `fit` and `predict` must be refactored so that `X`
121
+ and `Xnew` represent matrices with features as rows.
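
One way to sanity-check a front-end like this is to confirm that resampling the reformatted data gives the same result as reformatting already-resampled data. The sketch below does this in plain Julia under stated assumptions: `ToySupervised`, `features_as_rows`, `toy_reformat` and `toy_selectrows` are hypothetical stand-ins (not part of MLJModelInterface), and a named tuple of column vectors stands in for a table.

```julia
struct ToySupervised end

# Stand-in for a table-to-matrix conversion with one row per feature.
# Assumes `X` is a named tuple of equal-length column vectors.
features_as_rows(X::NamedTuple) = permutedims(hcat(values(X)...))

# Front-end for fit-like calls (X, y) and predict-like calls (X only).
toy_reformat(::ToySupervised, X, y) = (features_as_rows(X), y)
toy_reformat(::ToySupervised, X) = (features_as_rows(X),)

# Observations are columns of the matrix, so rows `I` of the original
# table correspond to columns `I` here. Both methods return tuples.
toy_selectrows(::ToySupervised, I, Xmatrix, y) =
    (view(Xmatrix, :, I), view(y, I))
toy_selectrows(::ToySupervised, I, Xmatrix) = (view(Xmatrix, :, I),)

model = ToySupervised()
X = (a = [1.0, 2.0, 3.0], b = [4.0, 5.0, 6.0])
y = [10.0, 20.0, 30.0]
I = [1, 3]

# Resample after reformatting ...
lhs = toy_selectrows(model, I, toy_reformat(model, X, y)...)

# ... and reformat after resampling the original table and target.
X_sub = map(col -> col[I], X)
rhs = toy_reformat(model, X_sub, y[I])

consistent = lhs[1] == rhs[1] && lhs[2] == rhs[2]
```

The same check applies to the single-argument methods used for `predict`, where the returned tuple has length one.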
61
122
62
123
"""
63
124
reformat(model::Model, args...) = args
64
125
65
126
"""
66
127
selectrows(::Model, I, data...) -> sampled_data
67
128
68
- Models optionally overload `selectrows` for efficient resampling of
69
- training data. Here `data` is the ouput of calling `reformat` on
70
- user-provided data. The fallback assumes `data` is a tuple and calls
129
+ A model overloads `selectrows` whenever it buys into the optional
130
+ `reformat` front-end for data preprocessing. See [`reformat`](@ref)
131
+ for details. The fallback assumes `data` is a tuple and calls
71
132
`selectrows(X, I)` for each `X` in `data`, returning the results in a
72
133
new tuple of the same length. This call makes sense when `X` is a
73
134
table, abstract vector or abstract matrix. In the last two cases, a