Skip to content

Commit e9e18a3

Browse files
committed
Major cleanup
1 parent 1a99ea7 commit e9e18a3

File tree

7 files changed

+74
-83
lines changed

7 files changed

+74
-83
lines changed

ext/StatsLearnModelsMLJModelInterfaceExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ function SLM.fit(model::MI.Model, input, output)
1818
y = Tables.getcolumn(cols, target)
1919
data = MI.reformat(model, input, y)
2020
fitresult, _... = MI.fit(model, 0, data...)
21-
SLM.FittedModel(model, (fitresult, target))
21+
SLM.FittedStatsLearnModel(model, (fitresult, target))
2222
end
2323

24-
function SLM.predict(fmodel::SLM.FittedModel{<:MI.Model}, table)
24+
function SLM.predict(fmodel::SLM.FittedStatsLearnModel{<:MI.Model}, table)
2525
(; model, cache) = fmodel
2626
fitresult, target = cache
2727
data = MI.reformat(model, table)

src/interface.jl

Lines changed: 33 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,71 +3,58 @@
33
# ------------------------------------------------------------------
44

55
"""
6-
StatsLearnModels.fit(model, input, output) -> FittedModel
6+
StatsLearnModel(model, invars, outvars)
77
8-
Fit statistical learning `model` using features in `input` table
9-
and targets in `output` table. Returns a fitted model with all
10-
the necessary information for prediction with the `predict` function.
11-
"""
12-
function fit end
8+
Wrap a (possibly external) `model` with selectors of
9+
input variables `invars` and output variables `outvars`.
1310
14-
"""
15-
StatsLearnModels.predict(model::FittedModel, table)
11+
## Examples
1612
17-
Predict the target values using the fitted statistical learning `model`
18-
and a new `table` of features.
13+
```julia
14+
StatsLearnModel(DecisionTreeClassifier(), ["x1","x2"], "y")
15+
StatsLearnModel(DecisionTreeClassifier(), 1:3, "target")
16+
```
1917
"""
20-
function predict end
21-
22-
"""
23-
StatsLearnModels.FittedModel(model, cache)
24-
25-
Wrapper type used to save learning model and auxiliary
26-
variables needed for prediction.
27-
"""
28-
struct FittedModel{M,C}
18+
struct StatsLearnModel{M,I<:ColumnSelector,O<:ColumnSelector}
2919
model::M
30-
cache::C
20+
invars::I
21+
outvars::O
3122
end
3223

33-
Base.show(io::IO, ::FittedModel{M}) where {M} = print(io, "FittedModel{$(nameof(M))}")
24+
StatsLearnModel(model, invars, outvars) = StatsLearnModel(model, selector(invars), selector(outvars))
3425

3526
"""
36-
StatsLearnModels.StatsLearnModel(model, incols, outcols)
27+
fit(model, input, output)
3728
38-
Wrapper type for learning models used for dispatch purposes.
29+
Fit statistical learning `model` using features in `input` table
30+
and targets in `output` table. Returns a fitted model with all
31+
the necessary information for prediction with the `predict` function.
3932
"""
40-
struct StatsLearnModel{M,I<:ColumnSelector,O<:ColumnSelector}
41-
model::M
42-
input::I
43-
output::O
44-
end
45-
46-
StatsLearnModel(model, incols, outcols) = StatsLearnModel(model, selector(incols), selector(outcols))
33+
function fit end
4734

4835
function Base.show(io::IO, model::StatsLearnModel{M}) where {M}
4936
println(io, "StatsLearnModel{$(nameof(M))}")
50-
println(io, "├─ input: $(model.input)")
51-
print(io, "└─ output: $(model.output)")
37+
println(io, "├─ features: $(model.invars)")
38+
print(io, "└─ targets: $(model.outvars)")
5239
end
5340

5441
"""
55-
StatsLearnModels.model(lmodel::StatsLearnModel)
56-
57-
Returns the model of the `lmodel`.
58-
"""
59-
model(lmodel::StatsLearnModel) = lmodel.model
42+
FittedStatsLearnModel(model, cache)
6043
44+
Wrap the statistical learning `model` with the `cache`
45+
produced during the [`fit`](@ref) stage.
6146
"""
62-
StatsLearnModels.input(lmodel::StatsLearnModel)
63-
64-
Returns the input column selection of the `lmodel`.
65-
"""
66-
input(lmodel::StatsLearnModel) = lmodel.input
47+
struct FittedStatsLearnModel{M,C}
48+
model::M
49+
cache::C
50+
end
6751

6852
"""
69-
StatsLearnModels.output(lmodel::StatsLearnModel)
70-
71-
Returns the output column selection of the `lmodel`.
53+
predict(model::FittedStatsLearnModel, table)
54+
55+
Predict targets using the fitted statistical
56+
learning `model` and a new `table` of features.
7257
"""
73-
output(lmodel::StatsLearnModel) = lmodel.output
58+
function predict end
59+
60+
Base.show(io::IO, ::FittedStatsLearnModel{M}) where {M} = print(io, "FittedStatsLearnModel{$(nameof(M))}")

src/learn.jl

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33
# ------------------------------------------------------------------
44

55
"""
6-
Learn(train, model, incols => outcols)
6+
Learn(train, model, invars => outvars)
77
8-
Fits the statistical learning `model` using the input columns, selected by `incols`,
9-
and the output columns, selected by `outcols`, from the `train` table.
10-
11-
The column selection can be a single column identifier (index or name),
12-
a collection of identifiers or a regular expression (regex).
8+
Fits the statistical learning `model` to `train` table,
9+
using the selectors of input variables `invars` and
10+
output variables `outvars`.
1311
1412
# Examples
1513
@@ -21,12 +19,12 @@ Learn(train, model, [1, 2, 3] => [:d, :e])
2119
Learn(train, model, r"[abc]" => ["d", "e"])
2220
```
2321
"""
24-
struct Learn{M<:FittedModel} <: StatelessFeatureTransform
22+
struct Learn{M<:FittedStatsLearnModel} <: StatelessFeatureTransform
2523
model::M
26-
input::Vector{Symbol}
24+
invars::Vector{Symbol}
2725
end
2826

29-
Learn(train, model, (incols, outcols)::Pair) = Learn(train, StatsLearnModel(model, incols, outcols))
27+
Learn(train, model, (invars, outvars)::Pair) = Learn(train, StatsLearnModel(model, invars, outvars))
3028

3129
function Learn(train, lmodel::StatsLearnModel)
3230
if !Tables.istable(train)
@@ -35,21 +33,26 @@ function Learn(train, lmodel::StatsLearnModel)
3533

3634
cols = Tables.columns(train)
3735
names = Tables.columnnames(cols)
38-
innms = lmodel.input(names)
39-
outnms = lmodel.output(names)
36+
invars = lmodel.invars(names)
37+
outvars = lmodel.outvars(names)
4038

41-
input = (; (nm => Tables.getcolumn(cols, nm) for nm in innms)...)
42-
output = (; (nm => Tables.getcolumn(cols, nm) for nm in outnms)...)
39+
input = (; (var => Tables.getcolumn(cols, var) for var in invars)...)
40+
output = (; (var => Tables.getcolumn(cols, var) for var in outvars)...)
4341

4442
fmodel = fit(lmodel.model, input, output)
45-
Learn(fmodel, innms)
43+
44+
Learn(fmodel, invars)
4645
end
4746

4847
isrevertible(::Type{<:Learn}) = false
4948

5049
function applyfeat(transform::Learn, feat, prep)
50+
model = transform.model
51+
vars = transform.invars
52+
5153
cols = Tables.columns(feat)
52-
pairs = (nm => Tables.getcolumn(cols, nm) for nm in transform.input)
54+
pairs = (var => Tables.getcolumn(cols, var) for var in vars)
5355
test = (; pairs...) |> Tables.materializer(feat)
54-
predict(transform.model, test), nothing
56+
57+
predict(model, test), nothing
5558
end

src/models/decisiontree.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ function fit(model::DecisionTreeModel, input, output)
1717
y = Tables.getcolumn(cols, outnm)
1818
X = Tables.matrix(input)
1919
DT.fit!(model, X, y)
20-
FittedModel(model, outnm)
20+
FittedStatsLearnModel(model, outnm)
2121
end
2222

23-
function predict(fmodel::FittedModel{<:DecisionTreeModel}, table)
23+
function predict(fmodel::FittedStatsLearnModel{<:DecisionTreeModel}, table)
2424
outnm = fmodel.cache
2525
X = Tables.matrix(table)
2626
= DT.predict(fmodel.model, X)

src/models/glm.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ function fit(model::GLMModel, input, output)
2626
X = Tables.matrix(input)
2727
y = Tables.getcolumn(cols, outnm)
2828
fitted = _fit(model, X, y)
29-
FittedModel(model, (fitted, outnm))
29+
FittedStatsLearnModel(model, (fitted, outnm))
3030
end
3131

32-
function predict(fmodel::FittedModel{<:GLMModel}, table)
32+
function predict(fmodel::FittedStatsLearnModel{<:GLMModel}, table)
3333
model, outnm = fmodel.cache
3434
X = Tables.matrix(table)
3535
= GLM.predict(model, X)

src/models/nearestneighbors.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ function fit(model::NearestNeighborsModel, input, output)
3434
else
3535
NN.BallTree(data, metric; leafsize, reorder)
3636
end
37-
FittedModel(model, (tree, outnm, outcol))
37+
FittedStatsLearnModel(model, (tree, outnm, outcol))
3838
end
3939

40-
function predict(fmodel::FittedModel{<:NearestNeighborsModel}, table)
40+
function predict(fmodel::FittedStatsLearnModel{<:NearestNeighborsModel}, table)
4141
(; model, cache) = fmodel
4242
tree, outnm, outcol = cache
4343
data = Tables.matrix(table, transpose=true)

test/runtests.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,24 @@ const SLM = StatsLearnModels
2121
@testset "show" begin
2222
model = DecisionTreeClassifier()
2323
fmodel = SLM.fit(model, input[train, :], output[train, :])
24-
@test sprint(show, fmodel) == "FittedModel{DecisionTreeClassifier}"
24+
@test sprint(show, fmodel) == "FittedStatsLearnModel{DecisionTreeClassifier}"
2525
end
2626

2727
@testset "StatsLearnModel" begin
2828
# accessor functions
2929
model = DecisionTreeClassifier()
30-
incols = selector([:a, :b])
31-
outcols = selector(:c)
32-
lmodel = SLM.StatsLearnModel(model, incols, outcols)
33-
@test SLM.model(lmodel) === model
34-
@test SLM.input(lmodel) === incols
35-
@test SLM.output(lmodel) === outcols
30+
invars = selector([:a, :b])
31+
outvars = selector(:c)
32+
lmodel = SLM.StatsLearnModel(model, invars, outvars)
33+
@test lmodel.model === model
34+
@test lmodel.invars === invars
35+
@test lmodel.outvars === outvars
3636
# show
3737
lmodel = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
3838
@test sprint(show, lmodel) == """
3939
StatsLearnModel{DecisionTreeClassifier}
40-
├─ input: [:a, :b]
41-
└─ output: :c"""
40+
├─ features: [:a, :b]
41+
└─ targets: :c"""
4242
end
4343

4444
@testset "models" begin
@@ -105,17 +105,18 @@ const SLM = StatsLearnModels
105105

106106
@testset "Learn" begin
107107
Random.seed!(123)
108-
outcol = :target
109-
incols = setdiff(propertynames(iris), [outcol])
108+
outvar = :target
109+
invars = setdiff(propertynames(iris), [outvar])
110+
outvars = outvar
110111
model = DecisionTreeClassifier()
111-
transform = Learn(iris[train, :], model, incols => outcol)
112+
transform = Learn(iris[train, :], model, invars => outvars)
112113
@test !isrevertible(transform)
113114
pred = transform(iris[test, :])
114115
accuracy = count(pred.target .== iris.target[test]) / length(test)
115116
@test accuracy > 0.9
116117

117118
# throws
118119
# training data is not a table
119-
@test_throws ArgumentError Learn(nothing, model, incols => outcol)
120+
@test_throws ArgumentError Learn(nothing, model, invars => outvars)
120121
end
121122
end

0 commit comments

Comments
 (0)