Commit 1bd56dc

Merge pull request #76 from alan-turing-institute/mlj2
Add facility for adding a data front-end to a model implementation
2 parents 2c358e9 + 9c937d8 commit 1bd56dc

4 files changed: +81 -14 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "MLJModelInterface"
 uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 authors = ["Thibaut Lienart and Anthony Blaom"]
-version = "0.3.6"
+version = "0.3.7"
 
 [deps]
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/MLJModelInterface.jl

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ export @mlj_model, metadata_pkg, metadata_model
 # model api
 export fit, update, update_data, transform, inverse_transform,
     fitted_params, predict, predict_mode, predict_mean, predict_median,
-    predict_joint, evaluate, clean!
+    predict_joint, evaluate, clean!, reformat
 
 # model traits
 export input_scitype, output_scitype, target_scitype,

src/model_api.jl

Lines changed: 73 additions & 12 deletions
@@ -1,31 +1,92 @@
 """
-every model interface must implement a `fit` method of the form
-`fit(model, verb::Integer, training_args...) -> fitresult, cache, report`
+    fit(model, verbosity, data...) -> fitresult, cache, report
+
+All models must implement a `fit` method. Here `data` is the
+output of `reformat` on user-provided data, or some some resampling
+thereof. The fallback of `reformat` returns the user-provided data
+(eg, a table).
+
 """
 function fit end
 
 # fallback for static transformations
-fit(::Static, ::Integer, a...) = (nothing, nothing, nothing)
+fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)
 
 # fallbacks for supervised models that don't support sample weights:
-fit(m::Supervised, verb::Integer, X, y, w) = fit(m, verb, X, y)
+fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)
 
-# this operation can be optionally overloaded to provide access to
-# fitted parameters (eg, coeficients of linear model):
-fitted_params(::Model, fitres) = (fitresult=fitres,)
+"""
+    update(model, verbosity, fitresult, cache, data...)
+
+Models may optionally implement an `update` method. The fallback calls
+`fit`.
 
 """
-each model interface may overload the `update` refitting method
+update(m::Model, verbosity, fitresult, cache, data...) =
+    fit(m, verbosity, data...)
+
+# to support online learning in the future:
+# https://github.com/alan-turing-institute/MLJ.jl/issues/60 :
+function update_data end
+
 """
-update(m::Model, verb::Integer, fitres, cache, a...) = fit(m, verb, a...)
+    MLJModelInterface.reformat(model, args...) -> data
+
+Models optionally overload `reformat` to define transformations of
+user-supplied data into some model-specific representation (e.g., from
+a table to a matrix). When implemented, the MLJ user can avoid
+repeating such transformations unnecessarily, and can additionally
+make use of more efficient row subsampling, which is then based on the
+model-specific representation of data, rather than the
+user-representation. When `reformat` is overloaded,
+`selectrows(::Model, ...)` must be as well (see
+[`selectrows`](@ref)). Furthermore, the model `fit` method(s), and
+operations, such as `predict` and `transform`, must be refactored to
+act on the model-specific representions of the data.
+
+To implement the `reformat` data front-end for a model, refer to
+"Implementing a data front-end" in the [MLJ
+manual](https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/).
+
 
 """
-each model interface may overload the `update_data` refitting method for online learning
+reformat(model::Model, args...) = args
+
 """
-function update_data end
+    selectrows(::Model, I, data...) -> sampled_data
+
+A model overloads `selectrows` whenever it buys into the optional
+`reformat` front-end for data preprocessing. See [`reformat`](@ref)
+for details. The fallback assumes `data` is a tuple and calls
+`selectrows(X, I)` for each `X` in `data`, returning the results in a
+new tuple of the same length. This call makes sense when `X` is a
+table, abstract vector or abstract matrix. In the last two cases, a
+new object and *not* a view is returned.
+
+"""
+selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)
+
+# this operation can be optionally overloaded to provide access to
+# fitted parameters (eg, coeficients of linear model):
+"""
+    fitted_params(model, fitresult) -> human_readable_fitresult # named_tuple
+
+Models may overload `fitted_params`. The fallback returns
+`(fitresult=fitresult,)`.
+
+Other training-related outcomes should be returned in the `report`
+part of the tuple returned by `fit`.
+
+"""
+fitted_params(::Model, fitresult) = (fitresult=fitresult,)
 
 """
-supervised methods must implement the `predict` operation
+
+    predict(model, fitresult, new_data...)
+
+`Supervised` models must implement the `predict` operation. Here
+`new_data` is the output of `reformat` called on user-specified data.
+
 """
 function predict end
 
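The fallback `selectrows(::Model, I, data...)` added in this hunk can be illustrated without loading the package. The following is a minimal sketch, not the library code: `Model`, `ToyModel`, and the two single-object `selectrows` methods are hypothetical stand-ins defined locally (the per-object `selectrows` used by the real fallback is defined elsewhere and handles general tables). It only shows the fallback mapping over the data tuple and, for vectors, returning new objects rather than views.

```julia
# Hypothetical stand-ins (assumptions for illustration, not MLJModelInterface itself)
abstract type Model end
struct ToyModel <: Model end

# Simplified per-object row selection: vectors and column tables (named tuples of vectors)
selectrows(X::AbstractVector, I) = X[I]                 # indexing with a range copies
selectrows(X::NamedTuple, I) = map(col -> col[I], X)

# Fallback with the same shape as the one in the diff: apply row selection to each data item
selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)

X = (x1 = [2, 4, 6],)
y = [10.0, 20.0, 30.0]

Xs, ys = selectrows(ToyModel(), 2:3, X, y)

# Mutating the sample leaves the original data untouched, because a copy was returned
ys[1] = -1.0
```

A model that stores data in a different layout (for example, observations as matrix columns) would need its own `selectrows(::MyModel, I, data...)` method, which is the requirement the docstring states.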
test/model_api.jl

Lines changed: 6 additions & 0 deletions
@@ -7,6 +7,12 @@ end
 
 mutable struct APIx1 <: Static end
 
+@testset "selectrows(model, data...)" begin
+    X = (x1 = [2, 4, 6],)
+    y = [10.0, 20.0, 30.0]
+    @test selectrows(APIx0(), 2:3, X, y) == ((x1 = [4, 6],), [20.0, 30.0])
+end
+
 @testset "fit-x" begin
     m0 = APIx0(f0=1)
     m1 = APIx0b(f0=3)
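
As a rough illustration of how the pieces introduced in this commit fit together, here is a self-contained sketch of a data front-end. Everything in it is an assumption made for illustration: `Model`, `ToyRegressor`, and `to_matrix` are hypothetical, and the generic functions are local stand-ins rather than the ones exported by MLJModelInterface. The pattern is the one described in the `reformat` docstring: convert user data once, subsample the converted data, and have `fit` and `predict` act on the model-specific representation.

```julia
using LinearAlgebra  # least-squares solve via `\`

# Hypothetical stand-ins for the package generics (assumptions, not the real API objects)
abstract type Model end
reformat(::Model, args...) = args          # same shape as the fallback in the diff

struct ToyRegressor <: Model end           # hypothetical model type

# Column table (named tuple of equal-length vectors) -> p x n matrix, observations as columns
to_matrix(X::NamedTuple) = permutedims(hcat(values(X)...))

# Data front-end: user data -> model-specific representation
reformat(::ToyRegressor, X, y) = (to_matrix(X), y)
reformat(::ToyRegressor, X) = (to_matrix(X),)

# Row selection must match the model-specific layout (observations are columns here)
selectrows(::ToyRegressor, I, Xmat, y) = (Xmat[:, I], y[I])
selectrows(::ToyRegressor, I, Xmat) = (Xmat[:, I],)

# `fit` and `predict` consume the output of `reformat`, not the user's table
function fit(::ToyRegressor, verbosity, Xmat, y)
    coefs = Matrix(Xmat') \ y              # ordinary least squares, no intercept
    return coefs, nothing, NamedTuple()    # fitresult, cache, report
end

predict(::ToyRegressor, coefs, Xmat) = Xmat' * coefs

model = ToyRegressor()
X = (x1 = [1.0, 2.0, 3.0, 4.0], x2 = [1.0, 0.0, 1.0, 0.0])
y = [5.0, 4.0, 9.0, 8.0]                   # constructed as 2*x1 + 3*x2

data = reformat(model, X, y)               # done once, reusable across resamples
train = selectrows(model, 1:3, data...)
fitresult, cache, report = fit(model, 0, train...)

holdout = selectrows(model, 4:4, reformat(model, X)...)
yhat = predict(model, fitresult, holdout...)
```

In this sketch the table-to-matrix conversion happens once in `reformat`, and each resample only slices matrix columns, which is the efficiency argument made in the docstring.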
