Skip to content

Commit bbe133c

Browse files
authored
Merge pull request #91 from alan-turing-institute/iteration
Add support for iterative models
2 parents 59016f8 + 33a9d3e commit bbe133c

File tree

5 files changed

+30
-7
lines changed

5 files changed

+30
-7
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -10,7 +10,7 @@ StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
1010

1111
[compat]
1212
ScientificTypes = "^1"
13-
StatisticalTraits = "^0.1"
13+
StatisticalTraits = "^0.1.1"
1414
julia = "^1"
1515

1616
[extras]

src/MLJModelInterface.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ const MODEL_TRAITS = [
2121
:implemented_methods,
2222
:hyperparameters,
2323
:hyperparameter_types,
24-
:hyperparameter_ranges]
24+
:hyperparameter_ranges,
25+
:iteration_parameter,
26+
:supports_training_losses]
2527

2628
# ------------------------------------------------------------------------
2729
# Dependencies (ScientificTypes and StatisticalTraits have none)
@@ -49,7 +51,7 @@ export @mlj_model, metadata_pkg, metadata_model
4951
# model api
5052
export fit, update, update_data, transform, inverse_transform,
5153
fitted_params, predict, predict_mode, predict_mean, predict_median,
52-
predict_joint, evaluate, clean!, reformat
54+
predict_joint, evaluate, clean!, reformat, training_losses
5355

5456
# model traits
5557
for trait in MODEL_TRAITS

src/model_api.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
fit(model, verbosity, data...) -> fitresult, cache, report
2+
MLJModelInterface.fit(model, verbosity, data...) -> fitresult, cache, report
33
44
All models must implement a `fit` method. Here `data` is the
55
output of `reformat` on user-provided data, or some some resampling
@@ -16,7 +16,7 @@ fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)
1616
fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)
1717

1818
"""
19-
update(model, verbosity, fitresult, cache, data...)
19+
MLJModelInterface.update(model, verbosity, fitresult, cache, data...)
2020
2121
Models may optionally implement an `update` method. The fallback calls
2222
`fit`.
@@ -25,6 +25,20 @@ Models may optionally implement an `update` method. The fallback calls
2525
update(m::Model, verbosity, fitresult, cache, data...) =
2626
fit(m, verbosity, data...)
2727

28+
"""
29+
MLJModelInterface.training_losses(model::M, report)
30+
31+
If `M` is an iterative model type which calculates training losses,
32+
implement this method to return an `AbstractVector` of the losses
33+
in historical order. If the model calculates scores instead, then the
34+
sign of the scores should be reversed.
35+
36+
The following trait overload is alse required:
37+
`supports_training_losses(::Type{<:M}) = true`
38+
39+
"""
40+
training_losses(model, report) = nothing
41+
2842
# to support online learning in the future:
2943
# https://github.com/alan-turing-institute/MLJ.jl/issues/60 :
3044
function update_data end
@@ -53,7 +67,7 @@ manual](https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_gen
5367
reformat(model::Model, args...) = args
5468

5569
"""
56-
selectrows(::Model, I, data...) -> sampled_data
70+
MLJModelInterface.selectrows(::Model, I, data...) -> sampled_data
5771
5872
A model overloads `selectrows` whenever it buys into the optional
5973
`reformat` front-end for data preprocessing. See [`reformat`](@ref)

test/model_api.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,18 @@ end
2424
@test fit(m1, 1, randn(2), randn(2), 5) == (7, nothing, nothing)
2525
# default fitted params
2626
@test M.fitted_params(m1, 7) == (fitresult=7,)
27+
# default iteration_parameter
28+
@test M.training_losses(m0, nothing) === nothing
2729
# static
2830
s1 = APIx1()
2931
@test fit(s1, 1, 0) == (nothing, nothing, nothing)
3032

3133
#update fallback = fit
3234
@test update(m0, 1, 5, nothing, randn(2), 5) == (5, nothing, nothing)
35+
36+
# training losses:
37+
f, c, r = MLJModelInterface.fit(m0, 1, rand(2), rand(2))
38+
@test M.training_losses(m0, r) === nothing
3339
end
3440

3541
struct DummyUnivariateFinite end

test/model_traits.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ bar(::P1) = nothing
3939
@test is_wrapper(ms) == false
4040
@test supports_online(ms) == false
4141
@test supports_weights(ms) == false
42+
@test iteration_parameter(ms) === nothing
4243

4344
@test hyperparameter_ranges(md) == (nothing,)
4445

0 commit comments

Comments
 (0)