|
1 | 1 | """
|
2 |
| -every model interface must implement a `fit` method of the form |
3 |
| -`fit(model, verb::Integer, training_args...) -> fitresult, cache, report` |
| 2 | + fit(model, verbosity, data...) -> fitresult, cache, report |
| 3 | +
|
| 4 | +All models must implement a `fit` method. Here `data` is the |
| 5 | +output of `reformat` on user-provided data, or some some resampling |
| 6 | +thereof. The fallback of `reformat` returns the user-provided data |
| 7 | +(eg, a table). |
| 8 | +
|
4 | 9 | """
|
5 | 10 | function fit end
|
6 | 11 |
|
7 | 12 | # fallback for static transformations
|
8 |
| -fit(::Static, ::Integer, a...) = (nothing, nothing, nothing) |
| 13 | +fit(::Static, ::Integer, data...) = (nothing, nothing, nothing) |
9 | 14 |
|
10 | 15 | # fallbacks for supervised models that don't support sample weights:
|
11 |
| -fit(m::Supervised, verb::Integer, X, y, w) = fit(m, verb, X, y) |
| 16 | +fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y) |
12 | 17 |
|
13 |
| -# this operation can be optionally overloaded to provide access to |
14 |
| -# fitted parameters (eg, coeficients of linear model): |
15 |
| -fitted_params(::Model, fitres) = (fitresult=fitres,) |
| 18 | +""" |
| 19 | + update(model, verbosity, fitresult, cache, data...) |
| 20 | +
|
| 21 | +Models may optionally implement an `update` method. The fallback calls |
| 22 | +`fit`. |
16 | 23 |
|
17 | 24 | """
|
18 |
| -each model interface may overload the `update` refitting method |
| 25 | +update(m::Model, verbosity, fitresult, cache, data...) = |
| 26 | + fit(m, verbosity, data...) |
| 27 | + |
| 28 | +# to support online learning in the future: |
| 29 | +# https://github.com/alan-turing-institute/MLJ.jl/issues/60 : |
| 30 | +function update_data end |
| 31 | + |
19 | 32 | """
|
20 |
| -update(m::Model, verb::Integer, fitres, cache, a...) = fit(m, verb, a...) |
| 33 | + MLJModelInterface.reformat(model, args...) -> data |
| 34 | +
|
| 35 | +Models optionally overload `reformat` to define transformations of |
| 36 | +user-supplied training data `args` into some model-specific |
| 37 | +representation `data` (e.g., from a table to a matrix). The fallback |
| 38 | +returns `args` (no transformation). |
| 39 | +
|
| 40 | +If `mach` is a machine with `mach.model == model` then calling `fit!(mach)` |
| 41 | +will either call |
| 42 | +
|
| 43 | + fit(model, verbosity, data...) |
| 44 | +
|
| 45 | +or |
| 46 | +
|
| 47 | + update(model, verbosity, data...) |
| 48 | +
|
| 49 | +where `data = reformat(model, mach.args...)`. This means that |
| 50 | +overloading `reformat` alters the form of the arguments expected by |
| 51 | +`fit` and `update`. |
| 52 | +
|
| 53 | +If, instead, one calls `fit!(mach, rows=I)`, then `data` in the above |
| 54 | +`fit`/`update` calls is replaced with `selectrows(model, I, |
| 55 | +data...)`. So overloading `reformat` generally requires overloading of |
| 56 | +`selectrows` also, to specify how the model-specific representation of |
| 57 | +training data is resampled. |
| 58 | +
|
| 59 | +Note `data` is always a tuple, but it needn't have the same length as |
| 60 | +`args`. |
21 | 61 |
|
22 | 62 | """
|
23 |
| -each model interface may overload the `update_data` refitting method for online learning |
| 63 | +reformat(model::Model, args...) = args |
| 64 | + |
24 | 65 | """
|
25 |
| -function update_data end |
| 66 | + selectrows(::Model, I, data...) -> sampled_data |
| 67 | +
|
| 68 | +Models optionally overload `selectrows` for efficient resampling of |
| 69 | +training data. Here `data` is the ouput of calling `reformat` on |
| 70 | +user-provided data. The fallback assumes `data` is a tuple and calls |
| 71 | +`selectrows(X, I)` for each `X` in `data`, returning the results in a |
| 72 | +new tuple of the same length. This call makes sense when `X` is a |
| 73 | +table, abstract vector or abstract matrix. In the last two cases, a |
| 74 | +new object and *not* a view is returned. |
| 75 | +
|
| 76 | +""" |
| 77 | +selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data) |
| 78 | + |
| 79 | +# this operation can be optionally overloaded to provide access to |
| 80 | +# fitted parameters (eg, coeficients of linear model): |
| 81 | +""" |
| 82 | + fitted_params(model, fitresult) -> human_readable_fitresult # named_tuple |
| 83 | +
|
| 84 | +Models may overload `fitted_params`. The fallback returns |
| 85 | +`(fitresult=fitresult,)`. |
| 86 | +
|
| 87 | +Other training-related outcomes should be returned in the `report` |
| 88 | +part of the tuple returned by `fit`. |
| 89 | +
|
| 90 | +""" |
| 91 | +fitted_params(::Model, fitresult) = (fitresult=fitresult,) |
26 | 92 |
|
27 | 93 | """
|
28 |
| -supervised methods must implement the `predict` operation |
| 94 | +
|
| 95 | + predict(model, fitresult, new_data...) |
| 96 | +
|
| 97 | +`Supervised` models must implement the `predict` operation. Here |
| 98 | +`new_data` is the output of `reformat` called on user-specified data. |
| 99 | +
|
29 | 100 | """
|
30 | 101 | function predict end
|
31 | 102 |
|
|
0 commit comments