Skip to content

Commit 39cdbbe

Browse files
committed
add generic interface tests from MLJTestInterface.jl
typo
1 parent 98677fc commit 39cdbbe

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ julia = "1.6"
1818

1919
[extras]
2020
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
21+
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
2122
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2223
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2324

2425
[targets]
25-
test = ["MLJBase", "StableRNGs", "Test"]
26+
test = ["MLJBase", "MLJTestInterface", "StableRNGs", "Test"]

test/runtests.jl

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,44 @@ using MLJBase
55
using StableRNGs
66
using Random
77
using Tables
8-
Random.seed!(1234)
9-
10-
stable_rng() = StableRNGs.StableRNG(123)
8+
import MLJTestInterface
119

1210
# load code to be tested:
1311
import DecisionTree
1412
using MLJDecisionTreeInterface
1513

14+
Random.seed!(1234)
15+
16+
@testset "generic interface tests" begin
17+
@testset "regressors" begin
18+
failures, summary = MLJTestInterface.test(
19+
[DecisionTreeRegressor, RandomForestRegressor],
20+
MLJTestInterface.make_regression()...;
21+
mod=@__MODULE__,
22+
verbosity=0, # bump to debug
23+
throw=false, # set to true to debug
24+
)
25+
@test isempty(failures)
26+
end
27+
@testset "classifiers" begin
28+
for data in [
29+
MLJTestInterface.make_binary(),
30+
MLJTestInterface.make_multiclass(),
31+
]
32+
failures, summary = MLJTestInterface.test(
33+
[DecisionTreeClassifier, RandomForestClassifier, AdaBoostStumpClassifier],
34+
data...;
35+
mod=@__MODULE__,
36+
verbosity=0, # bump to debug
37+
throw=false, # set to true to debug
38+
)
39+
@test isempty(failures)
40+
end
41+
end
42+
end
43+
44+
stable_rng() = StableRNGs.StableRNG(123)
45+
1646
# get some test data:
1747
Xraw, yraw = @load_iris
1848
X = Tables.matrix(Xraw);
@@ -100,13 +130,13 @@ y1 = Xraw.x2;
100130
y2 = float.(Xraw.x3);
101131

102132
rgs = DecisionTreeRegressor(rng=stable_rng())
103-
X, y, features = MMI.reformat(rgs, Xraw, y2)
133+
X, y, features = MLJBase.reformat(rgs, Xraw, y2)
104134

105135
fitresult, _, _ = MLJBase.fit(rgs, 1, X, y, features)
106136
@test rms(predict(rgs, fitresult, X), y) < 1.5
107137

108138
clf = DecisionTreeClassifier()
109-
X, y, features, _classes = MMI.reformat(clf, Xraw, y1)
139+
X, y, features, _classes = MLJBase.reformat(clf, Xraw, y1)
110140

111141
fitresult, _, _ = MLJBase.fit(clf, 1, X, y, features, _classes)
112142
@test sum(predict(clf, fitresult, X) .== y1) == 0 # perfect prediction

0 commit comments

Comments
 (0)