
Commit 6629619

Merge pull request #9 from JuliaAI/dev

Expose `rng`, addressing #4

2 parents: 8b9557e + ab1e98e
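
For orientation, here is a minimal sketch of what this change enables on the MLJ side. This is illustrative only, not part of the commit; it assumes MLJ is installed alongside this package, and uses MLJ's standard `@load` macro and `make_blobs` demo-data helper:

    using MLJ
    import Random: MersenneTwister

    # load the model type this package registers with MLJ
    Tree = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0

    X, y = make_blobs(100, 3)   # toy classification data

    tree = Tree(rng=123)                    # integer seed ...
    tree = Tree(rng=MersenneTwister(123))   # ... or any AbstractRNG

    mach = machine(tree, X, y)
    fit!(mach, verbosity=0)   # refitting with the same seed reproduces the model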

3 files changed: +58 -12 lines

Project.toml (3 additions, 3 deletions)

@@ -1,11 +1,12 @@
 name = "MLJDecisionTreeInterface"
 uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
 authors = ["Anthony D. Blaom <[email protected]>"]
-version = "0.1.2"
+version = "0.1.3"
 
 [deps]
 DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
 [compat]
 DecisionTree = "0.10"
@@ -15,8 +16,7 @@ julia = "1"
 [extras]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
-Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["CategoricalArrays", "MLJBase", "Random", "Test"]
+test = ["CategoricalArrays", "MLJBase", "Test"]
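
Note that Random moves from a test-only extra to a hard dependency: the source now does `using Random` and defaults the new hyperparameter to `Random.GLOBAL_RNG` (see the source diff below), while the test suite no longer needs to import Random itself.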

src/MLJDecisionTreeInterface.jl (28 additions, 9 deletions)

@@ -4,11 +4,12 @@ import MLJModelInterface
 import MLJModelInterface: @mlj_model, metadata_pkg, metadata_model,
     Table, Continuous, Count, Finite, OrderedFactor,
     Multiclass
-
-const MMI = MLJModelInterface
-
 import DecisionTree
 
+using Random
+import Random.GLOBAL_RNG
+
+const MMI = MLJModelInterface
 const DT = DecisionTree
 const PKG = "MLJDecisionTreeInterface"
 
@@ -73,10 +74,12 @@ from the DecisionTree.jl algorithm).
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having `>=thresh`
   combined purity
 
-- `pdf_smoothing=0.0`: threshold for smoothing the predicted scores
-
 - `display_depth=5`: max depth to show when displaying the tree
 
+- `rng=Random.GLOBAL_RNG`: random number generator or seed
+
+- `pdf_smoothing=0.0`: threshold for smoothing the predicted scores
+
 """
 @mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     max_depth::Int = (-)(1)::(_ ≥ -1)
@@ -88,6 +91,7 @@ from the DecisionTree.jl algorithm).
     merge_purity_threshold::Float64 = 1.0::(_ ≤ 1)
     pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
     display_depth::Int = 5::(_ ≥ 1)
+    rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
@@ -102,7 +106,8 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
         m.max_depth,
         m.min_samples_leaf,
         m.min_samples_split,
-        m.min_purity_increase)
+        m.min_purity_increase,
+        rng=m.rng)
     if m.post_prune
         tree = DT.prune_tree(tree, m.merge_purity_threshold)
     end
@@ -167,8 +172,11 @@ $RFC_DESCR
 
 - `sampling_fraction=0.7` fraction of samples to train each tree on
 
+- `rng=Random.GLOBAL_RNG`: random number generator or seed
+
 - `pdf_smoothing=0.0`: threshold for smoothing the predicted scores
 
+
 """
 @mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     max_depth::Int = (-)(1)::(_ ≥ -1)
@@ -179,6 +187,7 @@ $RFC_DESCR
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
     pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
+    rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
@@ -195,7 +204,8 @@ function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
         m.max_depth,
         m.min_samples_leaf,
         m.min_samples_split,
-        m.min_purity_increase)
+        m.min_purity_increase;
+        rng=m.rng)
     cache = nothing
     report = NamedTuple()
     return (forest, classes_seen, integers_seen), cache, report
@@ -280,6 +290,9 @@ are Deterministic.
 
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having `>=thresh`
   combined purity
+
+- `rng=Random.GLOBAL_RNG`: random number generator or seed
+
 """
 @mlj_model mutable struct DecisionTreeRegressor <: MMI.Deterministic
     max_depth::Int = (-)(1)::(_ ≥ -1)
@@ -289,6 +302,7 @@ are Deterministic.
     n_subfeatures::Int = 0::(_ ≥ -1)
     post_prune::Bool = false
     merge_purity_threshold::Float64 = 1.0::(0 ≤ _ ≤ 1)
+    rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
@@ -298,7 +312,8 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
         m.max_depth,
         m.min_samples_leaf,
         m.min_samples_split,
-        m.min_purity_increase)
+        m.min_purity_increase;
+        rng=m.rng)
 
     if m.post_prune
         tree = DT.prune_tree(tree, m.merge_purity_threshold)
@@ -337,6 +352,8 @@ $RFC_DESCR
 
 - `sampling_fraction=0.7` fraction of samples to train each tree on
 
+- `rng=Random.GLOBAL_RNG`: random number generator or seed
+
 - `pdf_smoothing=0.0`: threshold for smoothing the predicted scores
 
 """
@@ -349,6 +366,7 @@ $RFC_DESCR
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
     pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
+    rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
@@ -360,7 +378,8 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
        m.max_depth,
        m.min_samples_leaf,
       m.min_samples_split,
-       m.min_purity_increase)
+       m.min_purity_increase,
+       rng=m.rng)
     cache = nothing
     report = NamedTuple()
     return forest, cache, report
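
The new `rng::Union{AbstractRNG,Integer}` field is passed straight through as the `rng` keyword of `DT.build_tree`/`DT.build_forest`, so DecisionTree.jl itself accepts either an integer seed or a generator object. A rough sketch of the behaviour this relies on, against the raw DecisionTree.jl API (illustrative; assumes DecisionTree 0.10, as pinned in [compat] above):

    using DecisionTree
    import Random: MersenneTwister

    features = rand(MersenneTwister(1), 100, 4)
    labels   = rand(MersenneTwister(2), ["a", "b"], 100)

    # n_subfeatures = 2 randomizes the split search, so the rng
    # keyword now determines which tree gets built
    t1 = build_tree(labels, features, 2; rng=11)                   # integer seed
    t2 = build_tree(labels, features, 2; rng=11)                   # same seed, same tree
    t3 = build_tree(labels, features, 2; rng=MersenneTwister(11))  # explicit RNG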

test/runtests.jl (27 additions, 0 deletions)

@@ -104,3 +104,30 @@ rfr = RandomForestRegressor()
 m = machine(rfr, X, y)
 fit!(m)
 @test rms(predict(m, X), y) < 0.4
+
+N = 10
+function reproducibility(model, X, y, loss)
+    model.rng = 123
+    model.n_subfeatures = 1
+    mach = machine(model, X, y)
+    train, test = partition(eachindex(y), 0.7)
+    errs = map(1:N) do i
+        fit!(mach, rows=train, force=true, verbosity=0)
+        yhat = predict(mach, rows=test)
+        loss(yhat, y[test]) |> mean
+    end
+    return length(unique(errs)) == 1
+end
+
+@testset "reproducibility" begin
+    X, y = make_blobs();
+    loss = BrierLoss()
+    for model in [DecisionTreeClassifier(), RandomForestClassifier()]
+        @test reproducibility(model, X, y, loss)
+    end
+    X, y = make_regression();
+    loss = LPLoss(p=2)
+    for model in [DecisionTreeRegressor(), RandomForestRegressor()]
+        @test reproducibility(model, X, y, loss)
+    end
+end
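
Two notes on the test design: `model.n_subfeatures = 1` forces random feature subsampling at each split, so the `rng` is genuinely consumed even by the single-tree models (with the default `n_subfeatures = 0`, i.e. all features, a lone tree is essentially deterministic and the check would pass vacuously); and `length(unique(errs)) == 1` asserts that ten retrains from the same integer seed give bit-identical out-of-sample errors. For reproducibility that is also stable across Julia releases, a user could pass a StableRNG instead of an integer seed, e.g. (a suggestion, not something this commit tests):

    using StableRNGs   # assumes the StableRNGs package is installed
    model.rng = StableRNG(123)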
