Skip to content

Commit 65af839

Browse files
committed
update doc-strings
more more
1 parent 5ef9346 commit 65af839

File tree

1 file changed

+82
-21
lines changed

1 file changed

+82
-21
lines changed

src/model_api.jl

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,41 +33,102 @@ function update_data end
3333
MLJModelInterface.reformat(model, args...) -> data
3434
3535
Models optionally overload `reformat` to define transformations of
36-
user-supplied training data `args` into some model-specific
37-
representation `data` (e.g., from a table to a matrix). The fallback
38-
returns `args` (no transformation).
36+
user-supplied data into some model-specific representation (e.g., from
37+
a table to a matrix). Computational overheads associated with multiple
38+
`fit!`/`predict`/`transform` calls are then avoided, when memory
39+
resources allow. The fallback returns `args` (no transformation).
3940
40-
If `mach` is a machine with `mach.model == model` then calling `fit!(mach)`
41-
will either call
41+
Here "user-supplied data" is what the MLJ user supplies when
42+
constructing a machine, as in `machine(models, args...)`, which
43+
coincides with the arguments expected by `fit(model, verbosity,
44+
args...)` when `reformat` is not overloaded.
4245
43-
fit(model, verbosity, data...)
46+
Implementing a `reformat` data front-end is permitted for any `Model`
47+
subtype, except for subtypes of `Static`. Here is a complete list of
48+
responsibilities for such an implementation, for some
49+
`model::SomeModelType`:
4450
45-
or
51+
- A `reformat(model::SomeModelType, args...) -> data` method must be
52+
implemented for each form of `args...` appearing in a valid machine
53+
construction `machine(model, args...)` (there will be one for each
54+
possible signature of `fit(::SomeModelType, ...)`).
4655
47-
update(model, verbosity, data...)
56+
- Additionally, if not included above, there must be a single argument
57+
form of reformat, `reformat(model::SommeModelType, arg) -> (data,)`,
58+
serving as a data front-end for operations like `predict`. It must
59+
always hold that `reformat(model, args...)[1] = reformat(model,
60+
args[1])`.
4861
49-
where `data = reformat(model, mach.args...)`. This means that
50-
overloading `reformat` alters the form of the arguments expected by
51-
`fit` and `update`.
62+
**Warning.** `reformat(model::SomeModelType, args...)` must always
63+
return a tuple of the same length as `args`, even if this is one.
5264
53-
If, instead, one calls `fit!(mach, rows=I)`, then `data` in the above
54-
`fit`/`update` calls is replaced with `selectrows(model, I,
55-
data...)`. So overloading `reformat` generally requires overloading of
56-
`selectrows` also, to specify how the model-specific representation of
57-
training data is resampled.
65+
- `fit(model::SomeModelType, verbosity, data...)` should be
66+
implemented as if `data` is the output of `reformat(model,
67+
args...)`, where `args` is the data an MLJ user has bound to `model`
68+
in some machine. The same applies to any overloading of `update`.
5869
59-
Note `data` is always a tuple, but it needn't have the same length as
60-
`args`.
70+
- Each implemented operation, such as `predict` and `transform` - but
71+
excluding `inverse_transform` - must be defined as if its data
72+
arguments are `reformat`ed versions of user-supplied data. For
73+
example, in the supervised case, `data_new` in
74+
`predict(model::SomeModelType, fitresult, data_new)` is
75+
`reformat(model, Xnew)`, where `Xnew is the data provided by the MLJ
76+
user in a call `predict(mach, Xnew)` (`mach.model == model`).
77+
78+
- To specify how the model-specific representation of data is to be
79+
resampled, implement `selectrows(model::SomeModelType, I, data...)
80+
-> resampled_data` for each overloading of `reformat(model::SomeModel,
81+
args...) -> data` above. Here `I` is an arbitrary abstract integer
82+
vector or `:` (type `Colon`).
83+
84+
**Warning.** `selectrows(model::SomeModelType, I, args...)` must always
85+
return a tuple of the same length as `args`, even if this is one.
86+
87+
The fallback for `selectrows` is described at [`selectrows`](@ref).
88+
89+
90+
### Example
91+
92+
Suppose a supervised model type `SomeSupervised` supports sample
93+
weights, leading to two different `fit` signatures:
94+
95+
fit(model::SomeSupervised, verbosity, X, y)
96+
fit(model::SomeSupervised, verbosity, X, y, w)
97+
98+
predict(model::SomeSupervised, fitresult, Xnew)
99+
100+
Without a data front-end implemented, suppose `X` is expected to be a
101+
table and `y` a vector, but suppose the core algorithm always converts
102+
`X` to a matrix with features as rows (features corresponding to
103+
columns in the table). Then a new data-front end might look like
104+
this:
105+
106+
constant MMI = MLJModelInterface
107+
108+
# for fit:
109+
MMI.reformat(::SomeSupervised, X, y) = (MMI.matrix(X, transpose=true), y)
110+
MMI.reformat(::SomeSupervised, X, y, w) = (MMI.matrix(X, transpose=true), y, w)
111+
MMI.selectrows(::SomeSupervised, I, Xmatrix, y) =
112+
(view(Xmatrix, :, I), view(y, I))
113+
MMI.selectrows(::SomeSupervised, I, Xmatrix, y, w) =
114+
(view(Xmatrix, :, I), view(y, I), view(w, I))
115+
116+
# for predict:
117+
MMI.reformat(::SomeSupervised, X) = (MMI.matrix(X, transpose=true),)
118+
MMI.selectrows(::SomeSupervised, I, Xmatrix) = view(Xmatrix, I)
119+
120+
With these additions, `fit` and `predict` are refactored, so that `X`
121+
and `Xnew` represent matrices with features as rows.
61122
62123
"""
63124
reformat(model::Model, args...) = args
64125

65126
"""
66127
selectrows(::Model, I, data...) -> sampled_data
67128
68-
Models optionally overload `selectrows` for efficient resampling of
69-
training data. Here `data` is the ouput of calling `reformat` on
70-
user-provided data. The fallback assumes `data` is a tuple and calls
129+
A model overloads `selectrows` whenever it buys into the optional
130+
`reformat` front-end for data preprocessing. See [`reformat`](@ref)
131+
for details. The fallback assumes `data` is a tuple and calls
71132
`selectrows(X, I)` for each `X` in `data`, returning the results in a
72133
new tuple of the same length. This call makes sense when `X` is a
73134
table, abstract vector or abstract matrix. In the last two cases, a

0 commit comments

Comments
 (0)