@@ -5,14 +5,44 @@ using MLJBase
 using StableRNGs
 using Random
 using Tables
-Random.seed!(1234)
-
-stable_rng() = StableRNGs.StableRNG(123)
+import MLJTestInterface
 
 # load code to be tested:
 import DecisionTree
 using MLJDecisionTreeInterface
 
+Random.seed!(1234)
+
+@testset "generic interface tests" begin
+    @testset "regressors" begin
+        failures, summary = MLJTestInterface.test(
+            [DecisionTreeRegressor, RandomForestRegressor],
+            MLJTestInterface.make_regression()...;
+            mod=@__MODULE__,
+            verbosity=0, # bump to debug
+            throw=false, # set to true to debug
+        )
+        @test isempty(failures)
+    end
+    @testset "classifiers" begin
+        for data in [
+            MLJTestInterface.make_binary(),
+            MLJTestInterface.make_multiclass(),
+        ]
+            failures, summary = MLJTestInterface.test(
+                [DecisionTreeClassifier, RandomForestClassifier, AdaBoostStumpClassifier],
+                data...;
+                mod=@__MODULE__,
+                verbosity=0, # bump to debug
+                throw=false, # set to true to debug
+            )
+            @test isempty(failures)
+        end
+    end
+end
+
+stable_rng() = StableRNGs.StableRNG(123)
+
 # get some test data:
 Xraw, yraw = @load_iris
 X = Tables.matrix(Xraw);
@@ -100,13 +130,13 @@ y1 = Xraw.x2;
 y2 = float.(Xraw.x3);
 
 rgs = DecisionTreeRegressor(rng=stable_rng())
-X, y, features = MMI.reformat(rgs, Xraw, y2)
+X, y, features = MLJBase.reformat(rgs, Xraw, y2)
 
 fitresult, _, _ = MLJBase.fit(rgs, 1, X, y, features)
 @test rms(predict(rgs, fitresult, X), y) < 1.5
 
 clf = DecisionTreeClassifier()
-X, y, features, _classes = MMI.reformat(clf, Xraw, y1)
+X, y, features, _classes = MLJBase.reformat(clf, Xraw, y1)
 
 fitresult, _, _ = MLJBase.fit(clf, 1, X, y, features, _classes)
 @test sum(predict(clf, fitresult, X) .== y1) == 0 # perfect prediction