
Commit 9fbfad6

For a 0.2.3 release (#27)
* Implement StableRNGs throughout tests to fix reproducibility problems in CI (#26)
* add StableRNGs
* bump compat DecisionTree = "0.11"
* fix an invalid test
* add rng as hyper-parameter for AdaBoostStumpClassifier
* update docstring
* add reproducibility test for AdaBoostStumpClassifier
* srng -> stable_rng to improve readability of code
* bump 0.2.3
1 parent 673d494 commit 9fbfad6
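
The user-facing effect of the headline change is that AdaBoostStumpClassifier now accepts an rng hyper-parameter, so boosted-stump fits can be made reproducible. A minimal sketch of the intended usage, assuming an environment with MLJ and StableRNGs installed (the seed 123 and the iris data are purely illustrative):

    using MLJ, StableRNGs

    # fetch the DecisionTree.jl-backed model from the MLJ registry
    Booster = @load AdaBoostStumpClassifier pkg=DecisionTree

    X, y = @load_iris

    # seeding with a StableRNG keeps fits identical across runs and Julia versions
    model = Booster(n_iter=10, rng=StableRNG(123))
    mach = machine(model, X, y)
    fit!(mach, verbosity=0)
    yhat = predict_mode(mach, X)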

3 files changed: +46 -23 lines


Project.toml

Lines changed: 4 additions & 3 deletions
@@ -1,7 +1,7 @@
 name = "MLJDecisionTreeInterface"
 uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
 authors = ["Anthony D. Blaom <[email protected]>"]
-version = "0.2.2"
+version = "0.2.3"

 [deps]
 DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
@@ -10,15 +10,16 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

 [compat]
-DecisionTree = "0.10"
+DecisionTree = "0.11"
 MLJModelInterface = "1.4"
 Tables = "1.6"
 julia = "1.6"

 [extras]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["CategoricalArrays", "MLJBase", "Test"]
+test = ["CategoricalArrays", "MLJBase", "StableRNGs", "Test"]

src/MLJDecisionTreeInterface.jl

Lines changed: 4 additions & 2 deletions
@@ -156,6 +156,7 @@ end

 MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
     n_iter::Int = 10::(_ ≥ 1)
+    rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

 function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
@@ -165,8 +166,8 @@ function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
     classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
     integers_seen = MMI.int(classes_seen)

-    stumps, coefs = DT.build_adaboost_stumps(yplain, Xmatrix,
-                                             m.n_iter)
+    stumps, coefs =
+        DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng)
     cache = nothing
     report = NamedTuple()
     return (stumps, coefs, classes_seen, integers_seen), cache, report
@@ -586,6 +587,7 @@ Train the machine with `fit!(mach, rows=...)`.

 - `n_iter=10`: number of iterations of AdaBoost
+- `rng=Random.GLOBAL_RNG`: random number generator or seed

 # Operations
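
Because the new field is declared rng::Union{AbstractRNG,Integer}, the hyper-parameter may be supplied either as an RNG object or as an integer seed, matching the docstring's "random number generator or seed"; whichever is given is forwarded to DT.build_adaboost_stumps as rng=m.rng. A hedged sketch of the two spellings, assuming the interface package is loaded directly:

    using StableRNGs
    using MLJDecisionTreeInterface

    # either value satisfies the Union{AbstractRNG,Integer} constraint
    seeded = AdaBoostStumpClassifier(n_iter=10, rng=123)
    stable = AdaBoostStumpClassifier(n_iter=10, rng=StableRNG(123))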

test/runtests.jl

Lines changed: 38 additions & 18 deletions
@@ -2,17 +2,20 @@ using Test
 import CategoricalArrays
 import CategoricalArrays.categorical
 using MLJBase
+using StableRNGs
 using Random
 Random.seed!(1234)

+stable_rng() = StableRNGs.StableRNG(123)
+
 # load code to be tested:
 import DecisionTree
 using MLJDecisionTreeInterface

 # get some test data:
 X, y = @load_iris

-baretree = DecisionTreeClassifier()
+baretree = DecisionTreeClassifier(rng=stable_rng())

 baretree.max_depth = 1
 fitresult, cache, report = MLJBase.fit(baretree, 2, X, y);
@@ -50,13 +53,17 @@ using Random: seed!
 seed!(0)

 n,m = 10^3, 5;
-raw_features = rand(n,m);
-weights = rand(-1:1,m);
+raw_features = rand(stable_rng(), n,m);
+weights = rand(stable_rng(), -1:1,m);
 labels = raw_features * weights;
 features = MLJBase.table(raw_features);

-R1Tree = DecisionTreeRegressor(min_samples_leaf=5, merge_purity_threshold=0.1)
-R2Tree = DecisionTreeRegressor(min_samples_split=5)
+R1Tree = DecisionTreeRegressor(
+    min_samples_leaf=5,
+    merge_purity_threshold=0.1,
+    rng=stable_rng(),
+)
+R2Tree = DecisionTreeRegressor(min_samples_split=5, rng=stable_rng())
 model1, = MLJBase.fit(R1Tree,1, features, labels)

 vals1 = MLJBase.predict(R1Tree,model1,features)
@@ -75,11 +82,15 @@ vals2 = MLJBase.predict(R2Tree, model2, features)
 ## TEST ON ORDINAL FEATURES OTHER THAN CONTINUOUS

 N = 20
-X = (x1=rand(N), x2=categorical(rand("abc", N), ordered=true), x3=collect(1:N))
+X = (
+    x1=rand(stable_rng(),N),
+    x2=categorical(rand(stable_rng(), "abc", N), ordered=true),
+    x3=collect(1:N),
+)
 yfinite = X.x2
 ycont = float.(X.x3)

-rgs = DecisionTreeRegressor()
+rgs = DecisionTreeRegressor(rng=stable_rng())
 fitresult, _, _ = MLJBase.fit(rgs, 1, X, ycont)
 @test rms(predict(rgs, fitresult, X), ycont) < 1.5
@@ -90,10 +101,10 @@ fitresult, _, _ = MLJBase.fit(clf, 1, X, yfinite)

 # -- Ensemble

-rfc = RandomForestClassifier()
-abs = AdaBoostStumpClassifier()
+rfc = RandomForestClassifier(rng=stable_rng())
+abs = AdaBoostStumpClassifier(rng=stable_rng())

-X, y = MLJBase.make_blobs(100, 3; rng=555)
+X, y = MLJBase.make_blobs(100, 3; rng=stable_rng())

 m = machine(rfc, X, y)
 fit!(m)
@@ -103,19 +114,21 @@ m = machine(abs, X, y)
 fit!(m)
 @test accuracy(predict_mode(m, X), y) > 0.95

-X, y = MLJBase.make_regression(rng=5124)
-rfr = RandomForestRegressor()
+X, y = MLJBase.make_regression(rng=stable_rng())
+rfr = RandomForestRegressor(rng=stable_rng())
 m = machine(rfr, X, y)
 fit!(m)
 @test rms(predict(m, X), y) < 0.4

 N = 10
 function reproducibility(model, X, y, loss)
-    model.rng = 123
-    model.n_subfeatures = 1
+    if !(model isa AdaBoostStumpClassifier)
+        model.n_subfeatures = 1
+    end
     mach = machine(model, X, y)
     train, test = partition(eachindex(y), 0.7)
     errs = map(1:N) do i
+        model.rng = stable_rng()
         fit!(mach, rows=train, force=true, verbosity=0)
         yhat = predict(mach, rows=test)
         loss(yhat, y[test]) |> mean
@@ -124,14 +137,21 @@ function reproducibility(model, X, y, loss)
 end

 @testset "reporoducibility" begin
-    X, y = make_blobs();
+    X, y = make_blobs(rng=stable_rng());
     loss = BrierLoss()
-    for model in [DecisionTreeClassifier(), RandomForestClassifier()]
+    for model in [
+        DecisionTreeClassifier(),
+        RandomForestClassifier(),
+        AdaBoostStumpClassifier(),
+    ]
         @test reproducibility(model, X, y, loss)
     end
-    X, y = make_regression();
+    X, y = make_regression(rng=stable_rng());
     loss = LPLoss(p=2)
-    for model in [DecisionTreeRegressor(), RandomForestRegressor()]
+    for model in [
+        DecisionTreeRegressor(),
+        RandomForestRegressor(),
+    ]
         @test reproducibility(model, X, y, loss)
     end
 end
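
The reproducibility helper above re-seeds model.rng with a fresh stable_rng() before each of the N fits, so every fit starts from an identical RNG state; the final comparison of the N losses falls outside the hunks shown, but presumably checks that they all coincide. A condensed sketch of the same pattern for a single classifier, with illustrative data:

    using MLJBase, StableRNGs
    using MLJDecisionTreeInterface
    using Statistics: mean

    X, y = make_blobs(rng=StableRNG(123))
    loss = BrierLoss()

    # two fits started from identical RNG states must yield identical losses
    errs = map(1:2) do _
        model = DecisionTreeClassifier(rng=StableRNG(123), n_subfeatures=1)
        mach = machine(model, X, y)
        fit!(mach, verbosity=0)
        loss(predict(mach, X), y) |> mean
    end
    @assert errs[1] == errs[2]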
