Skip to content

Commit 5c56538

Browse files
committed
Add tests
1 parent e7acaa7 commit 5c56538

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
[deps]
2+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
3+
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
4+
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
5+
StatsLearnModels = "c146b59d-1589-421c-8e09-a22e554fd05c"
26
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/runtests.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
1-
using StatsLearnModels
1+
import StatsLearnModels as SLM
2+
using MLJ, MLJDecisionTreeInterface
3+
using DataFrames
24
using Test
35

46
@testset "StatsLearnModels.jl" begin
5-
# Write your tests here.
7+
iris = DataFrame(load_iris())
8+
input = iris[:, Not(:target)]
9+
output = iris[:, [:target]]
10+
Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
11+
train, test = partition(1:nrow(input), 0.7, rng=123)
12+
fmodel = SLM.fit(Tree(), input[train, :], output[train, :])
13+
pred = SLM.predict(fmodel, input[test, :])
14+
accuracy = count(pred.target .== output.target[test]) / length(test)
15+
@test accuracy > 0.9
616
end

0 commit comments

Comments
 (0)