Skip to content

Commit 996baac

Browse files
authored
Merge pull request #78 from alan-turing-institute/dev
For a 0.3.7 release
2 parents 8469bd9 + 1bd56dc commit 996baac

File tree

8 files changed

+136
-34
lines changed

8 files changed

+136
-34
lines changed

.github/workflows/TagBot.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
name: TagBot
22
on:
3-
schedule:
4-
- cron: 0 * * * *
3+
issue_comment:
4+
types:
5+
- created
6+
workflow_dispatch:
57
jobs:
68
TagBot:
9+
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'
710
runs-on: ubuntu-latest
811
steps:
912
- uses: JuliaRegistries/TagBot@v1
1013
with:
1114
token: ${{ secrets.GITHUB_TOKEN }}
15+
ssh: ${{ secrets.DOCUMENTER_KEY }}

.github/workflows/ci.yml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
name: CI
2+
on:
3+
pull_request:
4+
branches:
5+
- master
6+
- dev
7+
push:
8+
branches:
9+
- master
10+
- dev
11+
tags: '*'
12+
jobs:
13+
test:
14+
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
15+
runs-on: ${{ matrix.os }}
16+
strategy:
17+
fail-fast: false
18+
matrix:
19+
version:
20+
- '1.0'
21+
- '1'
22+
os:
23+
- ubuntu-latest
24+
arch:
25+
- x64
26+
steps:
27+
- uses: actions/checkout@v2
28+
- uses: julia-actions/setup-julia@v1
29+
with:
30+
version: ${{ matrix.version }}
31+
arch: ${{ matrix.arch }}
32+
- uses: actions/cache@v1
33+
env:
34+
cache-name: cache-artifacts
35+
with:
36+
path: ~/.julia/artifacts
37+
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
38+
restore-keys: |
39+
${{ runner.os }}-test-${{ env.cache-name }}-
40+
${{ runner.os }}-test-
41+
${{ runner.os }}-
42+
- uses: julia-actions/julia-buildpkg@v1
43+
- uses: julia-actions/julia-runtest@v1
44+
- uses: julia-actions/julia-processcoverage@v1
45+
- uses: codecov/codecov-action@v1
46+
with:
47+
file: lcov.info

.travis.yml

Lines changed: 0 additions & 16 deletions
This file was deleted.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "0.3.6"
4+
version = "0.3.7"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ machine learning models into
55
[MLJ](https://github.com/alan-turing-institute/MLJ.jl).
66

77

8-
| [MacOS/Linux] | Coverage |
8+
| Linux | Coverage |
99
| :-----------: | :------: |
10-
| [![Build Status](https://travis-ci.org/alan-turing-institute/MLJModelInterface.jl.svg?branch=master)](https://travis-ci.org/alan-turing-institute/MLJModelInterface.jl) | [![codecov.io](http://codecov.io/github/alan-turing-institute/MLJModelInterface.jl/coverage.svg?branch=master)](http://codecov.io/github/alan-turing-institute/MLJModelInterface.jl?branch=master) |
10+
| [![Build Status](https://github.com/alan-turing-institute/MLJModelInterface.jl/workflows/CI/badge.svg)](https://github.com/alan-turing-institute/MLJModelInterface.jl/actions) | [![codecov.io](http://codecov.io/github/alan-turing-institute/MLJModelInterface.jl/coverage.svg?branch=master)](http://codecov.io/github/alan-turing-institute/MLJModelInterface.jl?branch=master) |
1111

1212

1313
[MLJ](https://github.com/alan-turing-institute/MLJ.jl) is a framework

src/MLJModelInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ export @mlj_model, metadata_pkg, metadata_model
2525
# model api
2626
export fit, update, update_data, transform, inverse_transform,
2727
fitted_params, predict, predict_mode, predict_mean, predict_median,
28-
predict_joint, evaluate, clean!
28+
predict_joint, evaluate, clean!, reformat
2929

3030
# model traits
3131
export input_scitype, output_scitype, target_scitype,

src/model_api.jl

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,92 @@
11
"""
2-
every model interface must implement a `fit` method of the form
3-
`fit(model, verb::Integer, training_args...) -> fitresult, cache, report`
2+
fit(model, verbosity, data...) -> fitresult, cache, report
3+
4+
All models must implement a `fit` method. Here `data` is the
5+
output of `reformat` on user-provided data, or some some resampling
6+
thereof. The fallback of `reformat` returns the user-provided data
7+
(eg, a table).
8+
49
"""
510
function fit end
611

712
# fallback for static transformations
8-
fit(::Static, ::Integer, a...) = (nothing, nothing, nothing)
13+
fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)
914

1015
# fallbacks for supervised models that don't support sample weights:
11-
fit(m::Supervised, verb::Integer, X, y, w) = fit(m, verb, X, y)
16+
fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)
1217

13-
# this operation can be optionally overloaded to provide access to
14-
# fitted parameters (eg, coeficients of linear model):
15-
fitted_params(::Model, fitres) = (fitresult=fitres,)
18+
"""
19+
update(model, verbosity, fitresult, cache, data...)
20+
21+
Models may optionally implement an `update` method. The fallback calls
22+
`fit`.
1623
1724
"""
18-
each model interface may overload the `update` refitting method
25+
update(m::Model, verbosity, fitresult, cache, data...) =
26+
fit(m, verbosity, data...)
27+
28+
# to support online learning in the future:
29+
# https://github.com/alan-turing-institute/MLJ.jl/issues/60 :
30+
function update_data end
31+
1932
"""
20-
update(m::Model, verb::Integer, fitres, cache, a...) = fit(m, verb, a...)
33+
MLJModelInterface.reformat(model, args...) -> data
34+
35+
Models optionally overload `reformat` to define transformations of
36+
user-supplied data into some model-specific representation (e.g., from
37+
a table to a matrix). When implemented, the MLJ user can avoid
38+
repeating such transformations unnecessarily, and can additionally
39+
make use of more efficient row subsampling, which is then based on the
40+
model-specific representation of data, rather than the
41+
user-representation. When `reformat` is overloaded,
42+
`selectrows(::Model, ...)` must be as well (see
43+
[`selectrows`](@ref)). Furthermore, the model `fit` method(s), and
44+
operations, such as `predict` and `transform`, must be refactored to
45+
act on the model-specific representions of the data.
46+
47+
To implement the `reformat` data front-end for a model, refer to
48+
"Implementing a data front-end" in the [MLJ
49+
manual](https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/).
50+
2151
2252
"""
23-
each model interface may overload the `update_data` refitting method for online learning
53+
reformat(model::Model, args...) = args
54+
2455
"""
25-
function update_data end
56+
selectrows(::Model, I, data...) -> sampled_data
57+
58+
A model overloads `selectrows` whenever it buys into the optional
59+
`reformat` front-end for data preprocessing. See [`reformat`](@ref)
60+
for details. The fallback assumes `data` is a tuple and calls
61+
`selectrows(X, I)` for each `X` in `data`, returning the results in a
62+
new tuple of the same length. This call makes sense when `X` is a
63+
table, abstract vector or abstract matrix. In the last two cases, a
64+
new object and *not* a view is returned.
65+
66+
"""
67+
selectrows(::Model, I, data...) = map(X -> selectrows(X, I), data)
68+
69+
# this operation can be optionally overloaded to provide access to
70+
# fitted parameters (eg, coeficients of linear model):
71+
"""
72+
fitted_params(model, fitresult) -> human_readable_fitresult # named_tuple
73+
74+
Models may overload `fitted_params`. The fallback returns
75+
`(fitresult=fitresult,)`.
76+
77+
Other training-related outcomes should be returned in the `report`
78+
part of the tuple returned by `fit`.
79+
80+
"""
81+
fitted_params(::Model, fitresult) = (fitresult=fitresult,)
2682

2783
"""
28-
supervised methods must implement the `predict` operation
84+
85+
predict(model, fitresult, new_data...)
86+
87+
`Supervised` models must implement the `predict` operation. Here
88+
`new_data` is the output of `reformat` called on user-specified data.
89+
2990
"""
3091
function predict end
3192

test/model_api.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ end
77

88
mutable struct APIx1 <: Static end
99

10+
@testset "selectrows(model, data...)" begin
11+
X = (x1 = [2, 4, 6],)
12+
y = [10.0, 20.0, 30.0]
13+
@test selectrows(APIx0(), 2:3, X, y) == ((x1 = [4, 6],), [20.0, 30.0])
14+
end
15+
1016
@testset "fit-x" begin
1117
m0 = APIx0(f0=1)
1218
m1 = APIx0b(f0=3)

0 commit comments

Comments
 (0)