 using StatsLearnModels
-using MLJ, MLJDecisionTreeInterface
 using TableTransforms
 using DataFrames
 using Random
 using Test
 
+import MLJ, MLJDecisionTreeInterface
+
 const SLM = StatsLearnModels
 
 @testset "StatsLearnModels.jl" begin
+  iris = DataFrame(MLJ.load_iris())
+  input = iris[:, Not(:target)]
+  output = iris[:, [:target]]
+  train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
+
   @testset "interface" begin
-    Random.seed!(123)
-    iris = DataFrame(load_iris())
-    input = iris[:, Not(:target)]
-    output = iris[:, [:target]]
-    train, test = partition(1:nrow(input), 0.7, rng=123)
-    Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
-    fmodel = SLM.fit(Tree(), input[train, :], output[train, :])
-    pred = SLM.predict(fmodel, input[test, :])
-    accuracy = count(pred.target .== output.target[test]) / length(test)
-    @test accuracy > 0.9
+    @testset "MLJ" begin
+      Random.seed!(123)
+      Tree = MLJ.@load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
+      fmodel = SLM.fit(Tree(), input[train, :], output[train, :])
+      pred = SLM.predict(fmodel, input[test, :])
+      accuracy = count(pred.target .== output.target[test]) / length(test)
+      @test accuracy > 0.9
+    end
+
+    @testset "DecisionTree" begin
+      Random.seed!(123)
+      model = DecisionTreeClassifier()
+      fmodel = SLM.fit(model, input[train, :], output[train, :])
+      pred = SLM.predict(fmodel, input[test, :])
+      accuracy = count(pred.target .== output.target[test]) / length(test)
+      @test accuracy > 0.9
+    end
   end
 
   @testset "Learn" begin
     Random.seed!(123)
-    iris = DataFrame(load_iris())
     outcol = :target
     incols = setdiff(propertynames(iris), [outcol])
-    train, test = partition(1:nrow(iris), 0.7, rng=123)
-    Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
-    transform = Learn(iris[train, :], Tree(), incols => outcol)
+    model = DecisionTreeClassifier()
+    transform = Learn(iris[train, :], model, incols => outcol)
     @test !isrevertible(transform)
     pred = transform(iris[test, :])
     accuracy = count(pred.target .== iris.target[test]) / length(test)
     @test accuracy > 0.9
 
     # throws
     # training data is not a table
-    @test_throws ArgumentError Learn(nothing, Tree(), incols => outcol)
+    @test_throws ArgumentError Learn(nothing, model, incols => outcol)
   end
 end
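
For reference, a minimal sketch of the workflow these tests exercise, assembled only from calls that already appear in the diff above. It assumes StatsLearnModels re-exports DecisionTreeClassifier and Learn, since the tests use both without extra imports.

using StatsLearnModels, DataFrames
import MLJ

# iris table and a 70/30 row split, matching the test setup
iris = DataFrame(MLJ.load_iris())
train, test = MLJ.partition(1:nrow(iris), 0.7, rng=123)
input = iris[:, Not(:target)]
output = iris[:, [:target]]

# native DecisionTree.jl model, used directly without an MLJ wrapper
model = DecisionTreeClassifier()

# low-level interface: fit, then predict on held-out rows
fmodel = StatsLearnModels.fit(model, input[train, :], output[train, :])
pred = StatsLearnModels.predict(fmodel, input[test, :])

# TableTransforms-style interface: learn :target from the remaining columns
incols = setdiff(propertynames(iris), [:target])
transform = Learn(iris[train, :], model, incols => :target)
pred = transform(iris[test, :])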
|