Skip to content

Commit 72fc1b5

Browse files
committed
Refactor tests
1 parent f6d7f51 commit 72fc1b5

File tree

1 file changed

+54
-37
lines changed

1 file changed

+54
-37
lines changed

test/runtests.jl

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,25 @@ import MLJ, MLJDecisionTreeInterface
1313
const SLM = StatsLearnModels
1414

1515
@testset "StatsLearnModels.jl" begin
16-
iris = DataFrame(MLJ.load_iris())
17-
input = iris[:, Not(:target)]
18-
output = iris[:, [:target]]
19-
train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
20-
21-
@testset "show" begin
16+
@testset "Basic" begin
17+
# show method
18+
x1 = rand(1:0.1:10, 100)
19+
x2 = rand(1:0.1:10, 100)
20+
y = 2x1 + x2
21+
input = DataFrame(; x1, x2)
22+
output = DataFrame(; y)
23+
train, test = 1:70, 71:100
2224
model = DecisionTreeClassifier()
2325
fmodel = SLM.fit(model, input[train, :], output[train, :])
2426
@test sprint(show, fmodel) == "FittedStatsLearnModel{DecisionTreeClassifier}"
25-
end
2627

27-
@testset "StatsLearnModel" begin
28+
# show method
29+
lmodel = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
30+
@test sprint(show, lmodel) == """
31+
StatsLearnModel{DecisionTreeClassifier}
32+
├─ features: [:a, :b]
33+
└─ targets: :c"""
34+
2835
# accessor functions
2936
model = DecisionTreeClassifier()
3037
feats = selector([:a, :b])
@@ -33,46 +40,26 @@ const SLM = StatsLearnModels
3340
@test lmodel.model === model
3441
@test lmodel.feats === feats
3542
@test lmodel.targs === targs
36-
# show
37-
lmodel = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
38-
@test sprint(show, lmodel) == """
39-
StatsLearnModel{DecisionTreeClassifier}
40-
├─ features: [:a, :b]
41-
└─ targets: :c"""
4243
end
4344

44-
@testset "models" begin
45-
@testset "MLJ" begin
46-
Random.seed!(123)
47-
Tree = MLJ.@load(DecisionTreeClassifier, pkg = DecisionTree, verbosity = 0)
48-
fmodel = SLM.fit(Tree(), input[train, :], output[train, :])
49-
pred = SLM.predict(fmodel, input[test, :])
50-
accuracy = count(pred.target .== output.target[test]) / length(test)
51-
@test accuracy > 0.9
52-
end
53-
54-
@testset "DecisionTree" begin
55-
Random.seed!(123)
56-
model = DecisionTreeClassifier()
57-
fmodel = SLM.fit(model, input[train, :], output[train, :])
58-
pred = SLM.predict(fmodel, input[test, :])
59-
accuracy = count(pred.target .== output.target[test]) / length(test)
60-
@test accuracy > 0.9
61-
end
62-
45+
@testset "Models" begin
6346
@testset "NearestNeighbors" begin
6447
Random.seed!(123)
48+
iris = DataFrame(MLJ.load_iris())
49+
input = iris[:, Not(:target)]
50+
output = iris[:, [:target]]
51+
train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
6552
model = KNNClassifier(5)
6653
fmodel = SLM.fit(model, input[train, :], output[train, :])
6754
pred = SLM.predict(fmodel, input[test, :])
6855
accuracy = count(pred.target .== output.target[test]) / length(test)
6956
@test accuracy > 0.9
7057

7158
Random.seed!(123)
72-
a = rand(1:0.1:10, 100)
73-
b = rand(1:0.1:10, 100)
74-
y = 2a + b
75-
input = DataFrame(; a, b)
59+
x1 = rand(1:0.1:10, 100)
60+
x2 = rand(1:0.1:10, 100)
61+
y = 2x1 + x2
62+
input = DataFrame(; x1, x2)
7663
output = DataFrame(; y)
7764
model = KNNRegressor(5)
7865
fmodel = SLM.fit(model, input, output)
@@ -101,10 +88,27 @@ const SLM = StatsLearnModels
10188
pred = SLM.predict(fmodel, input)
10289
@test all(isapprox.(pred.y, output.y, atol=0.5))
10390
end
91+
92+
@testset "DecisionTree" begin
93+
Random.seed!(123)
94+
iris = DataFrame(MLJ.load_iris())
95+
input = iris[:, Not(:target)]
96+
output = iris[:, [:target]]
97+
train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
98+
model = DecisionTreeClassifier()
99+
fmodel = SLM.fit(model, input[train, :], output[train, :])
100+
pred = SLM.predict(fmodel, input[test, :])
101+
accuracy = count(pred.target .== output.target[test]) / length(test)
102+
@test accuracy > 0.9
103+
end
104104
end
105105

106106
@testset "Learn" begin
107107
Random.seed!(123)
108+
iris = DataFrame(MLJ.load_iris())
109+
input = iris[:, Not(:target)]
110+
output = iris[:, [:target]]
111+
train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
108112
outvar = :target
109113
feats = setdiff(propertynames(iris), [outvar])
110114
targs = outvar
@@ -119,4 +123,17 @@ const SLM = StatsLearnModels
119123
# training data is not a table
120124
@test_throws ArgumentError Learn(nothing, model, feats => targs)
121125
end
126+
127+
@testset "MLJ" begin
128+
Random.seed!(123)
129+
iris = DataFrame(MLJ.load_iris())
130+
input = iris[:, Not(:target)]
131+
output = iris[:, [:target]]
132+
train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
133+
Tree = MLJ.@load(DecisionTreeClassifier, pkg = DecisionTree, verbosity = 0)
134+
fmodel = SLM.fit(Tree(), input[train, :], output[train, :])
135+
pred = SLM.predict(fmodel, input[test, :])
136+
accuracy = count(pred.target .== output.target[test]) / length(test)
137+
@test accuracy > 0.9
138+
end
122139
end

0 commit comments

Comments
 (0)