Commit 5ef9346

add reformat method/fallback; add selectrows(::Model, ...) fallback
1 parent 2c358e9 commit 5ef9346

2 files changed: +89 -12 lines changed

src/model_api.jl

Lines changed: 83 additions & 12 deletions
@@ -1,31 +1,102 @@
 """
-every model interface must implement a `fit` method of the form
-`fit(model, verb::Integer, training_args...) -> fitresult, cache, report`
+    fit(model, verbosity, data...) -> fitresult, cache, report
+
+All models must implement a `fit` method. Here `data` is the
+output of `reformat` on user-provided data, or some resampling
+thereof. The fallback of `reformat` returns the user-provided data
+(e.g., a table).
+
 """
 function fit end

 # fallback for static transformations
-fit(::Static, ::Integer, a...) = (nothing, nothing, nothing)
+fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)

 # fallbacks for supervised models that don't support sample weights:
-fit(m::Supervised, verb::Integer, X, y, w) = fit(m, verb, X, y)
+fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)

-# this operation can be optionally overloaded to provide access to
-# fitted parameters (eg, coeficients of linear model):
-fitted_params(::Model, fitres) = (fitresult=fitres,)
+"""
+    update(model, verbosity, fitresult, cache, data...)
+
+Models may optionally implement an `update` method. The fallback calls
+`fit`.

 """
-each model interface may overload the `update` refitting method
+update(m::Model, verbosity, fitresult, cache, data...) =
+    fit(m, verbosity, data...)
+
+# to support online learning in the future:
+# https://github.com/alan-turing-institute/MLJ.jl/issues/60 :
+function update_data end
+
 """
-update(m::Model, verb::Integer, fitres, cache, a...) = fit(m, verb, a...)
+    MLJModelInterface.reformat(model, args...) -> data
+
+Models optionally overload `reformat` to define transformations of
+user-supplied training data `args` into some model-specific
+representation `data` (e.g., from a table to a matrix). The fallback
+returns `args` (no transformation).
+
+If `mach` is a machine with `mach.model == model` then calling `fit!(mach)`
+will either call
+
+    fit(model, verbosity, data...)
+
+or
+
+    update(model, verbosity, fitresult, cache, data...)
+
+where `data = reformat(model, mach.args...)`. This means that
+overloading `reformat` alters the form of the arguments expected by
+`fit` and `update`.
+
+If, instead, one calls `fit!(mach, rows=I)`, then `data` in the above
+`fit`/`update` calls is replaced with `selectrows(model, I,
+data...)`. So overloading `reformat` generally requires overloading of
+`selectrows` also, to specify how the model-specific representation of
+training data is resampled.
+
+Note `data` is always a tuple, but it needn't have the same length as
+`args`.

 """
-each model interface may overload the `update_data` refitting method for online learning
+reformat(model::Model, args...) = args
+
 """
-function update_data end
+    selectrows(::Model, I, data...) -> sampled_data
+
+Models optionally overload `selectrows` for efficient resampling of
+training data. Here `data` is the output of calling `reformat` on
+user-provided data. The fallback assumes `data` is a tuple and calls
+`selectrows(X, I)` for each `X` in `data`, returning the results in a
+new tuple of the same length. This call makes sense when `X` is a
+table, abstract vector or abstract matrix. In the last two cases, a
+new object and *not* a view is returned.
+
+"""
+selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)
+
+# this operation can be optionally overloaded to provide access to
+# fitted parameters (e.g., coefficients of linear model):
+"""
+    fitted_params(model, fitresult) -> human_readable_fitresult # named_tuple
+
+Models may overload `fitted_params`. The fallback returns
+`(fitresult=fitresult,)`.
+
+Other training-related outcomes should be returned in the `report`
+part of the tuple returned by `fit`.
+
+"""
+fitted_params(::Model, fitresult) = (fitresult=fitresult,)

 """
-supervised methods must implement the `predict` operation
+
+    predict(model, fitresult, new_data...)
+
+`Supervised` models must implement the `predict` operation. Here
+`new_data` is the output of `reformat` called on user-specified data.
+
 """
 function predict end
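The `reformat` docstring in this diff says that overloading `reformat` generally means overloading `selectrows` too. A minimal sketch of why, assuming a hypothetical `MatrixModel` that wants observations as matrix columns. To keep it runnable without the package, `Model` and the two-argument `selectrows` methods are simplified stand-ins defined locally (they are not the package's own definitions); only the two fallbacks are copied from the commit.

```julia
# Stand-in for the package's abstract type, so the sketch runs on its own.
abstract type Model end

# Simplified stand-ins for row selection on the data forms the fallback expects.
selectrows(X::NamedTuple, I) = map(col -> col[I], X)   # column table
selectrows(v::AbstractVector, I) = v[I]
selectrows(A::AbstractMatrix, I) = A[I, :]

# Fallbacks as added in this commit.
reformat(model::Model, args...) = args
selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)

# Hypothetical model that stores features as a (features x observations) matrix.
struct MatrixModel <: Model end

# Table in, transposed matrix out; the target is passed through unchanged.
reformat(::MatrixModel, X, y) = (permutedims(hcat(values(X)...)), y)

# Observations are now columns, so the row-indexing fallback would select
# the wrong axis. The overload says how to resample this representation.
selectrows(::MatrixModel, I, Xmat, y) = (Xmat[:, I], y[I])

X = (x1 = [2, 4, 6], x2 = [1, 3, 5])
y = [10.0, 20.0, 30.0]

data = reformat(MatrixModel(), X, y)               # what `fit` would receive
sampled = selectrows(MatrixModel(), 2:3, data...)  # what `fit` receives with rows=2:3
```

Here `sampled[1]` holds the second and third observations as columns. Without the `selectrows` overload, the fallback would index rows of the transposed matrix, which are features rather than observations.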

test/model_api.jl

Lines changed: 6 additions & 0 deletions
@@ -7,6 +7,12 @@ end

 mutable struct APIx1 <: Static end

+@testset "selectrows(model, data...)" begin
+    X = (x1 = [2, 4, 6],)
+    y = [10.0, 20.0, 30.0]
+    @test selectrows(APIx0(), 2:3, X, y) == ((x1 = [4, 6],), [20.0, 30.0])
+end
+
 @testset "fit-x" begin
     m0 = APIx0(f0=1)
     m1 = APIx0b(f0=3)
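The new test exercises the `selectrows` fallback, whose docstring notes that vectors and matrices come back as new objects rather than views. A small sketch of that behaviour on the same data as the test, assuming a hypothetical `DummyModel` with no overloads. `Model` and the two-argument `selectrows` methods are simplified local stand-ins, not the package's definitions.

```julia
# Stand-ins so the sketch is self-contained.
abstract type Model end
struct DummyModel <: Model end   # hypothetical model relying on the fallback

selectrows(X::NamedTuple, I) = map(col -> col[I], X)   # column table
selectrows(v::AbstractVector, I) = v[I]                # range indexing copies

# Fallback as added in this commit.
selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)

X = (x1 = [2, 4, 6],)
y = [10.0, 20.0, 30.0]

Xs, ys = selectrows(DummyModel(), 2:3, X, y)

# `ys` is a copy, so mutating it leaves the original `y` untouched.
ys[1] = -1.0
```

A view-based overload would avoid the copy, at the cost of the sample sharing memory with the original data.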
