Skip to content

Commit 52daf6e

Browse files
authored
Add Learn transform (#1)
* Add 'Learn' transform * Add compat * Fix typo * Update tests
1 parent e7f945d commit 52daf6e

File tree

6 files changed

+122
-36
lines changed

6 files changed

+122
-36
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ authors = ["Elias Carvalho <[email protected]> and contributors"]
44
version = "0.1.0"
55

66
[deps]
7+
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
8+
TableTransforms = "0d432bfd-3ee1-4ac1-886a-39f05cc69a3e"
79
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
810

911
[weakdeps]
@@ -13,6 +15,8 @@ MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
1315
StatsLearnModelsMLJModelInterfaceExt = "MLJModelInterface"
1416

1517
[compat]
18+
ColumnSelectors = "0.1"
1619
MLJModelInterface = "1.9"
20+
TableTransforms = "1.15"
1721
Tables = "1.11"
1822
julia = "1.9"

src/StatsLearnModels.jl

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,13 @@
# Main module of the StatsLearnModels.jl package.
# Defines the fit/predict interface and the `Learn` table transform.
module StatsLearnModels

using Tables
using ColumnSelectors: selector
using TableTransforms: StatelessFeatureTransform
import TableTransforms: applyfeat, isrevertible

export Learn

include("interface.jl")
include("learn.jl")

end

src/interface.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
"""
    StatsLearnModels.fit(model, input, output) -> FittedModel

Fit the statistical learning `model` with the features of the `input`
table and the targets of the `output` table. The returned fitted model
carries everything required for prediction with the `predict` function.
"""
function fit end

"""
    StatsLearnModels.predict(model::FittedModel, table)

Predict target values for a new `table` of features using a
previously fitted statistical learning `model`.
"""
function predict end

"""
    StatsLearnModels.FittedModel(model, cache)

Wrapper that bundles a learning `model` together with the auxiliary
`cache` of variables needed at prediction time.
"""
struct FittedModel{M,C}
  model::M
  cache::C
end

src/learn.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
"""
    Learn(model, train, incols, outcols)

Fits the statistical learning `model` using the input columns, selected by `incols`,
and the output columns, selected by `outcols`, from the `train` table.

The column selection can be a single column identifier (index or name),
a collection of identifiers or a regular expression (regex).

# Examples

```julia
Learn(model, train, [1, 2, 3], "d")
Learn(model, train, [:a, :b, :c], :d)
Learn(model, train, ["a", "b", "c"], 4)
Learn(model, train, [1, 2, 3], [:d, :e])
Learn(model, train, r"[abc]", ["d", "e"])
```
"""
struct Learn{M<:FittedModel} <: StatelessFeatureTransform
  model::M            # fitted model used at apply time
  input::Vector{Symbol} # names of the input (feature) columns
end
24+
# Outer constructor: resolves the column selections against the training
# table, fits the model, and stores the fitted model with the input names.
function Learn(model, train, incols, outcols)
  Tables.istable(train) || throw(ArgumentError("training data must be a table"))

  cols = Tables.columns(train)
  available = Tables.columnnames(cols)
  innames = selector(incols)(available)
  outnames = selector(outcols)(available)

  # build a named tuple of columns (a valid Tables.jl table) for the given names
  columntable(nms) = (; (nm => Tables.getcolumn(cols, nm) for nm in nms)...)

  fitted = fit(model, columntable(innames), columntable(outnames))
  Learn(fitted, innames)
end
41+
# Learn is stateless with respect to the applied table and cannot be reverted.
isrevertible(::Type{<:Learn}) = false

function applyfeat(transform::Learn, feat, prep)
  # select the stored input columns from the new table of features
  cols = Tables.columns(feat)
  selected = (; (nm => Tables.getcolumn(cols, nm) for nm in transform.input)...)
  # materialize with the same table type as the incoming table
  newfeat = Tables.materializer(feat)(selected)
  predict(transform.model, newfeat), nothing
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
33
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
44
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
55
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6+
TableTransforms = "0d432bfd-3ee1-4ac1-886a-39f05cc69a3e"
67
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/runtests.jl

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,41 @@
using StatsLearnModels
using MLJ, MLJDecisionTreeInterface
using TableTransforms
using DataFrames
using Random
using Test

const SLM = StatsLearnModels

@testset "StatsLearnModels.jl" begin
  @testset "interface" begin
    # fit/predict a decision tree on the iris dataset and check accuracy
    Random.seed!(123)
    iris = DataFrame(load_iris())
    input = iris[:, Not(:target)]
    output = iris[:, [:target]]
    Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
    train, test = partition(1:nrow(input), 0.7, rng=123)
    fitted = SLM.fit(Tree(), input[train, :], output[train, :])
    pred = SLM.predict(fitted, input[test, :])
    acc = count(pred.target .== output.target[test]) / length(test)
    @test acc > 0.9
  end

  @testset "Learn" begin
    # same workflow through the Learn table transform
    Random.seed!(123)
    iris = DataFrame(load_iris())
    outcol = :target
    incols = setdiff(propertynames(iris), [outcol])
    Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
    train, test = partition(1:nrow(iris), 0.7, rng=123)
    transform = Learn(Tree(), iris[train, :], incols, outcol)
    @test !isrevertible(transform)
    pred = transform(iris[test, :])
    acc = count(pred.target .== iris.target[test]) / length(test)
    @test acc > 0.9

    # throws
    # training data is not a table
    @test_throws ArgumentError Learn(Tree(), nothing, incols, outcol)
  end
end

0 commit comments

Comments
 (0)