
Commit d286415

eliascarv and juliohm authored
Add DecisionTree.jl models (#4)
* Add DecisionTree.jl models
* Update src/StatsLearnModels.jl
---------
Co-authored-by: Júlio Hoffimann <[email protected]>
1 parent a4f8322 commit d286415

4 files changed (+68 -17 lines)


Project.toml

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@ version = "0.2.0"
 
 [deps]
 ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
+DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 TableTransforms = "0d432bfd-3ee1-4ac1-886a-39f05cc69a3e"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

@@ -16,6 +17,7 @@ StatsLearnModelsMLJModelInterfaceExt = "MLJModelInterface"
 
 [compat]
 ColumnSelectors = "0.1"
+DecisionTree = "0.12"
 MLJModelInterface = "1.9"
 TableTransforms = "1.15"
 Tables = "1.11"

src/StatsLearnModels.jl

Lines changed: 15 additions & 1 deletion
@@ -9,9 +9,23 @@ using ColumnSelectors: selector
 using TableTransforms: StatelessFeatureTransform
 import TableTransforms: applyfeat, isrevertible
 
+import DecisionTree as DT
+using DecisionTree: AdaBoostStumpClassifier, DecisionTreeClassifier, RandomForestClassifier
+using DecisionTree: DecisionTreeRegressor, RandomForestRegressor
+
 include("interface.jl")
+include("models/decisiontree.jl")
 include("learn.jl")
 
-export Learn
+export
+  # transform
+  Learn,
+
+  # models
+  AdaBoostStumpClassifier,
+  DecisionTreeClassifier,
+  RandomForestClassifier,
+  DecisionTreeRegressor,
+  RandomForestRegressor
 
 end
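With these exports in place, a DecisionTree.jl model can be passed straight to the Learn transform, with no MLJ @load step. A minimal sketch, mirroring the "Learn" testset in test/runtests.jl below (the iris dataset, split, and column selection are the ones used there; MLJ is imported only to load the dataset):

using StatsLearnModels
using DataFrames
import MLJ  # only used to load the iris dataset, as in the tests

iris = DataFrame(MLJ.load_iris())
train, test = MLJ.partition(1:nrow(iris), 0.7, rng=123)

# learn :target from the remaining columns with a newly exported model
incols = setdiff(propertynames(iris), [:target])
transform = Learn(iris[train, :], DecisionTreeClassifier(), incols => :target)
pred = transform(iris[test, :])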

src/models/decisiontree.jl

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+const DTModel = Union{
+  AdaBoostStumpClassifier,
+  DecisionTreeClassifier,
+  RandomForestClassifier,
+  DecisionTreeRegressor,
+  RandomForestRegressor
+}
+
+function fit(model::DTModel, input, output)
+  cols = Tables.columns(output)
+  names = Tables.columnnames(cols)
+  outcol = first(names)
+  y = Tables.getcolumn(cols, outcol)
+  X = Tables.matrix(input)
+  DT.fit!(model, X, y)
+  FittedModel(model, outcol)
+end
+
+function predict(fmodel::FittedModel{<:DTModel}, table)
+  outcol = fmodel.cache
+  X = Tables.matrix(table)
+  ŷ = DT.predict(fmodel.model, X)
+  (; outcol => ŷ) |> Tables.materializer(table)
+end
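For reference, the fit/predict pair defined in this new file can also be called directly. A minimal sketch following the "DecisionTree" testset in test/runtests.jl below (the SLM alias, iris data, and train/test split come from the tests):

using StatsLearnModels
using DataFrames
import MLJ  # only used to load the iris dataset, as in the tests

const SLM = StatsLearnModels

iris = DataFrame(MLJ.load_iris())
input = iris[:, Not(:target)]
output = iris[:, [:target]]
train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)

# fit a DecisionTree.jl model without going through MLJ's @load
model = DecisionTreeClassifier()
fmodel = SLM.fit(model, input[train, :], output[train, :])

# predict returns a table whose column matches the output column name
pred = SLM.predict(fmodel, input[test, :])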

test/runtests.jl

Lines changed: 27 additions & 16 deletions
@@ -1,41 +1,52 @@
 using StatsLearnModels
-using MLJ, MLJDecisionTreeInterface
 using TableTransforms
 using DataFrames
 using Random
 using Test
 
+import MLJ, MLJDecisionTreeInterface
+
 const SLM = StatsLearnModels
 
 @testset "StatsLearnModels.jl" begin
+  iris = DataFrame(MLJ.load_iris())
+  input = iris[:, Not(:target)]
+  output = iris[:, [:target]]
+  train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
+
   @testset "interface" begin
-    Random.seed!(123)
-    iris = DataFrame(load_iris())
-    input = iris[:, Not(:target)]
-    output = iris[:, [:target]]
-    train, test = partition(1:nrow(input), 0.7, rng=123)
-    Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
-    fmodel = SLM.fit(Tree(), input[train, :], output[train, :])
-    pred = SLM.predict(fmodel, input[test, :])
-    accuracy = count(pred.target .== output.target[test]) / length(test)
-    @test accuracy > 0.9
+    @testset "MLJ" begin
+      Random.seed!(123)
+      Tree = MLJ.@load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
+      fmodel = SLM.fit(Tree(), input[train, :], output[train, :])
+      pred = SLM.predict(fmodel, input[test, :])
+      accuracy = count(pred.target .== output.target[test]) / length(test)
+      @test accuracy > 0.9
+    end
+
+    @testset "DecisionTree" begin
+      Random.seed!(123)
+      model = DecisionTreeClassifier()
+      fmodel = SLM.fit(model, input[train, :], output[train, :])
+      pred = SLM.predict(fmodel, input[test, :])
+      accuracy = count(pred.target .== output.target[test]) / length(test)
+      @test accuracy > 0.9
+    end
   end
 
   @testset "Learn" begin
     Random.seed!(123)
-    iris = DataFrame(load_iris())
     outcol = :target
     incols = setdiff(propertynames(iris), [outcol])
-    train, test = partition(1:nrow(iris), 0.7, rng=123)
-    Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
-    transform = Learn(iris[train, :], Tree(), incols => outcol)
+    model = DecisionTreeClassifier()
+    transform = Learn(iris[train, :], model, incols => outcol)
     @test !isrevertible(transform)
     pred = transform(iris[test, :])
     accuracy = count(pred.target .== iris.target[test]) / length(test)
    @test accuracy > 0.9
 
     # throws
     # training data is not a table
-    @test_throws ArgumentError Learn(nothing, Tree(), incols => outcol)
+    @test_throws ArgumentError Learn(nothing, model, incols => outcol)
   end
 end
