|
1 | 1 | """
|
2 |
| -every model interface must implement a `fit` method of the form |
3 |
| -`fit(model, verb::Integer, training_args...) -> fitresult, cache, report` |
| 2 | + fit(model, verbosity, data...) -> fitresult, cache, report |
| 3 | +
|
| 4 | +All models must implement a `fit` method. Here `data` is the |
| 5 | +output of `reformat` on user-provided data, or some resampling |
| 6 | +thereof. The fallback of `reformat` returns the user-provided data |
| 7 | +(e.g., a table). |
| 8 | +
|
4 | 9 | """
|
5 | 10 | function fit end
|
6 | 11 |
|
7 | 12 | # fallback for static transformations
|
8 |
| -fit(::Static, ::Integer, a...) = (nothing, nothing, nothing) |
| 13 | +fit(::Static, ::Integer, data...) = (nothing, nothing, nothing) |
9 | 14 |
|
10 | 15 | # fallbacks for supervised models that don't support sample weights:
|
11 |
| -fit(m::Supervised, verb::Integer, X, y, w) = fit(m, verb, X, y) |
| 16 | +fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y) |
12 | 17 |
|
13 |
| -# this operation can be optionally overloaded to provide access to |
14 |
| -# fitted parameters (e.g., coefficients of a linear model): |
15 |
| -fitted_params(::Model, fitres) = (fitresult=fitres,) |
| 18 | +""" |
| 19 | + update(model, verbosity, fitresult, cache, data...) |
| 20 | +
|
| 21 | +Models may optionally implement an `update` method. The fallback calls |
| 22 | +`fit`. |
16 | 23 |
|
17 | 24 | """
|
18 |
| -each model interface may overload the `update` refitting method |
| 25 | +update(m::Model, verbosity, fitresult, cache, data...) = |
| 26 | + fit(m, verbosity, data...) |
| 27 | + |
| 28 | +# to support online learning in the future: |
| 29 | +# https://github.com/alan-turing-institute/MLJ.jl/issues/60 : |
| 30 | +function update_data end |
| 31 | + |
19 | 32 | """
|
20 |
| -update(m::Model, verb::Integer, fitres, cache, a...) = fit(m, verb, a...) |
| 33 | + MLJModelInterface.reformat(model, args...) -> data |
| 34 | +
|
| 35 | +Models optionally overload `reformat` to define transformations of |
| 36 | +user-supplied data into some model-specific representation (e.g., from |
| 37 | +a table to a matrix). When implemented, the MLJ user can avoid |
| 38 | +repeating such transformations unnecessarily, and can additionally |
| 39 | +make use of more efficient row subsampling, which is then based on the |
| 40 | +model-specific representation of data, rather than the |
| 41 | +user-representation. When `reformat` is overloaded, |
| 42 | +`selectrows(::Model, ...)` must be as well (see |
| 43 | +[`selectrows`](@ref)). Furthermore, the model `fit` method(s), and |
| 44 | +operations, such as `predict` and `transform`, must be refactored to |
| 45 | +act on the model-specific representations of the data. |
| 46 | +
|
| 47 | +To implement the `reformat` data front-end for a model, refer to |
| 48 | +"Implementing a data front-end" in the [MLJ |
| 49 | +manual](https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/). |
| 50 | +
|
21 | 51 |
|
22 | 52 | """
|
23 |
| -each model interface may overload the `update_data` refitting method for online learning |
| 53 | +reformat(model::Model, args...) = args |
| 54 | + |
24 | 55 | """
|
25 |
| -function update_data end |
| 56 | + selectrows(::Model, I, data...) -> sampled_data |
| 57 | +
|
| 58 | +A model overloads `selectrows` whenever it buys into the optional |
| 59 | +`reformat` front-end for data preprocessing. See [`reformat`](@ref) |
| 60 | +for details. The fallback assumes `data` is a tuple and calls |
| 61 | +`selectrows(X, I)` for each `X` in `data`, returning the results in a |
| 62 | +new tuple of the same length. This call makes sense when `X` is a |
| 63 | +table, abstract vector or abstract matrix. In the last two cases, a |
| 64 | +new object and *not* a view is returned. |
| 65 | +
|
| 66 | +""" |
| 67 | +selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data) |
| 68 | + |
| 69 | +# this operation can be optionally overloaded to provide access to |
| 70 | +# fitted parameters (e.g., coefficients of a linear model): |
| 71 | +""" |
| 72 | + fitted_params(model, fitresult) -> human_readable_fitresult # named_tuple |
| 73 | +
|
| 74 | +Models may overload `fitted_params`. The fallback returns |
| 75 | +`(fitresult=fitresult,)`. |
| 76 | +
|
| 77 | +Other training-related outcomes should be returned in the `report` |
| 78 | +part of the tuple returned by `fit`. |
| 79 | +
|
| 80 | +""" |
| 81 | +fitted_params(::Model, fitresult) = (fitresult=fitresult,) |
26 | 82 |
|
27 | 83 | """
|
28 |
| -supervised methods must implement the `predict` operation |
| 84 | +
|
| 85 | + predict(model, fitresult, new_data...) |
| 86 | +
|
| 87 | +`Supervised` models must implement the `predict` operation. Here |
| 88 | +`new_data` is the output of `reformat` called on user-specified data. |
| 89 | +
|
29 | 90 | """
|
30 | 91 | function predict end
|
31 | 92 |
|
|
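The following is a minimal sketch of how the methods touched by this diff fit together. It does not load the real package. The abstract types `Model` and `Supervised`, and the fallbacks for `fit`, `update`, `reformat`, `selectrows` and `fitted_params`, are toy stand-ins written to mirror the added lines above. The two-argument `selectrows` methods for arrays and the `MeanRegressor` model are hypothetical and introduced only for illustration.

```julia
# Toy stand-ins for the interface shown in the diff (assumptions, not the real package).
abstract type Model end
abstract type Supervised <: Model end

function fit end
function predict end

# Fallbacks mirroring the added lines of the diff.
fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)
update(m::Model, verbosity, fitresult, cache, data...) =
    fit(m, verbosity, data...)
reformat(model::Model, args...) = args
fitted_params(::Model, fitresult) = (fitresult=fitresult,)

# Hypothetical row selection for plain arrays; returns new objects, not views.
selectrows(X::AbstractVector, I) = X[I]
selectrows(X::AbstractMatrix, I) = X[I, :]
selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)

# Hypothetical model: always predicts the mean of the training target.
struct MeanRegressor <: Supervised end

function fit(::MeanRegressor, verbosity, X, y)
    fitresult = sum(y) / length(y)
    cache = nothing
    report = (n_observations=length(y),)
    return fitresult, cache, report
end

predict(::MeanRegressor, fitresult, Xnew) = fill(fitresult, size(Xnew, 1))

model = MeanRegressor()
X = [1.0 2.0; 3.0 4.0; 5.0 6.0; 7.0 8.0]
y = [1.0, 2.0, 3.0, 4.0]

data = reformat(model, X, y)             # fallback: returns (X, y) unchanged
train = selectrows(model, 1:2, data...)  # first two rows of X and of y
fitresult, cache, report = fit(model, 0, train...)
yhat = predict(model, fitresult, reformat(model, X)...)
```

Because `MeanRegressor` does not overload `reformat`, `selectrows` or `update`, every call above except `fit` and `predict` goes through a fallback. Passing a fifth weight argument to `fit` falls back to the unweighted method in the same way.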