@@ -4,11 +4,12 @@ import MLJModelInterface
 import MLJModelInterface: @mlj_model, metadata_pkg, metadata_model,
     Table, Continuous, Count, Finite, OrderedFactor,
     Multiclass
-
-const MMI = MLJModelInterface
-
 import DecisionTree
 
+using Random
+import Random.GLOBAL_RNG
+
+const MMI = MLJModelInterface
 const DT = DecisionTree
 const PKG = "MLJDecisionTreeInterface"
 
@@ -73,10 +74,12 @@ from the DecisionTree.jl algorithm).
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having `>=thresh`
   combined purity
 
-- `pdf_smoothing=0.0`: threshold for smoothing the predicted scores
-
 - `display_depth=5`: max depth to show when displaying the tree
 
+- `rng=Random.GLOBAL_RNG`: random number generator or seed
+
+- `pdf_smoothing=0.0`: threshold for smoothing the predicted scores
+
 """
 @mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     max_depth::Int = (-)(1)::(_ ≥ -1)
@@ -88,6 +91,7 @@ from the DecisionTree.jl algorithm).
     merge_purity_threshold::Float64 = 1.0::(_ ≤ 1)
     pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
     display_depth::Int = 5::(_ ≥ 1)
+    rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
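Because `@mlj_model` generates a keyword constructor, the new `rng` field can be set to either an integer seed or an `AbstractRNG` instance at construction time. A minimal sketch (the seed values are illustrative, and it assumes the model types are accessible from the package namespace):

```julia
using Random
using MLJDecisionTreeInterface: DecisionTreeClassifier

tree_default = DecisionTreeClassifier()                           # uses GLOBAL_RNG
tree_seeded  = DecisionTreeClassifier(rng=1234)                   # Integer seed
tree_rng     = DecisionTreeClassifier(rng=MersenneTwister(1234))  # explicit generator
```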
@@ -102,7 +106,8 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
         m.max_depth,
         m.min_samples_leaf,
         m.min_samples_split,
-        m.min_purity_increase)
+        m.min_purity_increase,
+        rng=m.rng)
     if m.post_prune
         tree = DT.prune_tree(tree, m.merge_purity_threshold)
     end
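With `rng=m.rng` forwarded to `DT.build_tree`, a fixed integer seed should make repeated fits reproducible whenever the tree actually uses randomness, i.e. when `n_subfeatures` restricts the features considered per split. A sketch of such a check, assuming MLJBase is available for the machine interface and that DecisionTree.jl expands an integer seed into a fresh generator on each call:

```julia
using MLJBase, Random
using MLJDecisionTreeInterface: DecisionTreeClassifier

# toy data; any Tables-compatible X with a categorical target works
X = (x1 = rand(100), x2 = rand(100), x3 = rand(100))
y = coerce(rand(["a", "b"], 100), Multiclass)

model = DecisionTreeClassifier(n_subfeatures=2, rng=42)

mach1 = machine(model, X, y); fit!(mach1, verbosity=0)
mach2 = machine(model, X, y); fit!(mach2, verbosity=0)

# expected to agree for a fixed integer seed
predict_mode(mach1, X) == predict_mode(mach2, X)
```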
@@ -167,8 +172,11 @@ $RFC_DESCR
 
 - `sampling_fraction=0.7` fraction of samples to train each tree on
 
+- `rng=Random.GLOBAL_RNG`: random number generator or seed
+
 - `pdf_smoothing=0.0`: threshold for smoothing the predicted scores
 
+
 """
 @mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     max_depth::Int = (-)(1)::(_ ≥ -1)
@@ -179,6 +187,7 @@ $RFC_DESCR
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
     pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
+    rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
@@ -195,7 +204,8 @@ function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
         m.max_depth,
         m.min_samples_leaf,
         m.min_samples_split,
-        m.min_purity_increase)
+        m.min_purity_increase;
+        rng=m.rng)
     cache = nothing
     report = NamedTuple()
     return (forest, classes_seen, integers_seen), cache, report
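One behavioural difference between the two accepted forms is worth noting (this reflects how DecisionTree.jl commonly treats the keyword, so treat it as an assumption rather than a guarantee): an integer seed is expanded into a fresh generator inside each `build_forest` call, so retraining the same model reproduces the same forest, whereas a shared `AbstractRNG` object is mutated in place, so successive fits draw from a continuing stream and generally differ:

```julia
using Random
using MLJDecisionTreeInterface: RandomForestClassifier

# reseeded on every fit: retraining rebuilds the identical forest
forest_seeded = RandomForestClassifier(n_trees=50, rng=7)

# shared generator: its state advances with each fit, so refits differ,
# while the overall random stream stays under the caller's control
shared_rng    = MersenneTwister(7)
forest_shared = RandomForestClassifier(n_trees=50, rng=shared_rng)
```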
@@ -280,6 +290,9 @@ are Deterministic.
 
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having `>=thresh`
   combined purity
+
+- `rng=Random.GLOBAL_RNG`: random number generator or seed
+
 """
 @mlj_model mutable struct DecisionTreeRegressor <: MMI.Deterministic
     max_depth::Int = (-)(1)::(_ ≥ -1)
@@ -289,6 +302,7 @@ are Deterministic.
     n_subfeatures::Int = 0::(_ ≥ -1)
     post_prune::Bool = false
     merge_purity_threshold::Float64 = 1.0::(0 ≤ _ ≤ 1)
+    rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
@@ -298,7 +312,8 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
         m.max_depth,
         m.min_samples_leaf,
         m.min_samples_split,
-        m.min_purity_increase)
+        m.min_purity_increase;
+        rng=m.rng)
 
     if m.post_prune
         tree = DT.prune_tree(tree, m.merge_purity_threshold)
@@ -337,6 +352,8 @@ $RFC_DESCR
 
 - `sampling_fraction=0.7` fraction of samples to train each tree on
 
+- `rng=Random.GLOBAL_RNG`: random number generator or seed
+
 - `pdf_smoothing=0.0`: threshold for smoothing the predicted scores
 
 """
@@ -349,6 +366,7 @@ $RFC_DESCR
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
     pdf_smoothing::Float64 = 0.0::(0 ≤ _ ≤ 1)
+    rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end
 
 function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
@@ -360,7 +378,8 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
         m.max_depth,
         m.min_samples_leaf,
         m.min_samples_split,
-        m.min_purity_increase)
+        m.min_purity_increase,
+        rng=m.rng)
     cache = nothing
     report = NamedTuple()
     return forest, cache, report
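Finally, because the `@mlj_model` structs are mutable, the generator can also be swapped on an existing model between fits, like any other hyperparameter. A small sketch for the regressor (seed values again illustrative):

```julia
using Random
using MLJDecisionTreeInterface: RandomForestRegressor

model = RandomForestRegressor(n_trees=20)   # rng defaults to GLOBAL_RNG

model.rng = MersenneTwister(0)   # make the next fit reproducible
model.rng = 2021                 # or fall back to an integer seed
```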