@@ -2,17 +2,20 @@ using Test
 import CategoricalArrays
 import CategoricalArrays.categorical
 using MLJBase
+using StableRNGs
 using Random
 Random.seed!(1234)
 
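+# StableRNGs.StableRNG is reproducible across Julia versions; each call below
+# returns a fresh generator seeded at 123: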
+stable_rng() = StableRNGs.StableRNG(123)
+
 # load code to be tested:
 import DecisionTree
 using MLJDecisionTreeInterface
 
 # get some test data:
 X, y = @load_iris
 
-baretree = DecisionTreeClassifier()
+baretree = DecisionTreeClassifier(rng=stable_rng())
 
 baretree.max_depth = 1
 fitresult, cache, report = MLJBase.fit(baretree, 2, X, y);
@@ -50,13 +53,17 @@ using Random: seed!
 seed!(0)
 
 n,m = 10^3, 5;
-raw_features = rand(n,m);
-weights = rand(-1:1,m);
+raw_features = rand(stable_rng(), n,m);
+weights = rand(stable_rng(), -1:1,m);
 labels = raw_features * weights;
 features = MLJBase.table(raw_features);
 
-R1Tree = DecisionTreeRegressor(min_samples_leaf=5, merge_purity_threshold=0.1)
-R2Tree = DecisionTreeRegressor(min_samples_split=5)
+R1Tree = DecisionTreeRegressor(
+    min_samples_leaf=5,
+    merge_purity_threshold=0.1,
+    rng=stable_rng(),
+)
+R2Tree = DecisionTreeRegressor(min_samples_split=5, rng=stable_rng())
 model1, = MLJBase.fit(R1Tree, 1, features, labels)
 
 vals1 = MLJBase.predict(R1Tree, model1, features)
@@ -75,11 +82,15 @@ vals2 = MLJBase.predict(R2Tree, model2, features)
 ## TEST ON ORDINAL FEATURES OTHER THAN CONTINUOUS
 
 N = 20
-X = (x1=rand(N), x2=categorical(rand("abc", N), ordered=true), x3=collect(1:N))
+X = (
+    x1=rand(stable_rng(), N),
+    x2=categorical(rand(stable_rng(), "abc", N), ordered=true),
+    x3=collect(1:N),
+)
 yfinite = X.x2
 ycont = float.(X.x3)
 
-rgs = DecisionTreeRegressor()
+rgs = DecisionTreeRegressor(rng=stable_rng())
 fitresult, _, _ = MLJBase.fit(rgs, 1, X, ycont)
 @test rms(predict(rgs, fitresult, X), ycont) < 1.5
 
@@ -90,10 +101,10 @@ fitresult, _, _ = MLJBase.fit(clf, 1, X, yfinite)
 
 # -- Ensemble
 
-rfc = RandomForestClassifier()
-abs = AdaBoostStumpClassifier()
+rfc = RandomForestClassifier(rng=stable_rng())
+abs = AdaBoostStumpClassifier(rng=stable_rng())
 
-X, y = MLJBase.make_blobs(100, 3; rng=555)
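+# make_blobs accepts an AbstractRNG (not just an integer seed) via rng: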
+X, y = MLJBase.make_blobs(100, 3; rng=stable_rng())
 
 m = machine(rfc, X, y)
 fit!(m)
@@ -103,19 +114,21 @@ m = machine(abs, X, y)
 fit!(m)
 @test accuracy(predict_mode(m, X), y) > 0.95
 
-X, y = MLJBase.make_regression(rng=5124)
-rfr = RandomForestRegressor()
+X, y = MLJBase.make_regression(rng=stable_rng())
+rfr = RandomForestRegressor(rng=stable_rng())
 m = machine(rfr, X, y)
 fit!(m)
 @test rms(predict(m, X), y) < 0.4
 
 N = 10
 function reproducibility(model, X, y, loss)
-    model.rng = 123
-    model.n_subfeatures = 1
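+    # AdaBoostStumpClassifier has no n_subfeatures hyperparameter, so skip it: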
+    if !(model isa AdaBoostStumpClassifier)
+        model.n_subfeatures = 1
+    end
     mach = machine(model, X, y)
     train, test = partition(eachindex(y), 0.7)
     errs = map(1:N) do i
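+        # re-seed with an identical fresh RNG before every refit, so all N
+        # error estimates should coincide exactly: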
+        model.rng = stable_rng()
         fit!(mach, rows=train, force=true, verbosity=0)
         yhat = predict(mach, rows=test)
         loss(yhat, y[test]) |> mean
@@ -124,14 +137,21 @@ function reproducibility(model, X, y, loss)
 end
 
 @testset "reproducibility" begin
-    X, y = make_blobs();
+    X, y = make_blobs(rng=stable_rng());
     loss = BrierLoss()
-    for model in [DecisionTreeClassifier(), RandomForestClassifier()]
+    for model in [
+        DecisionTreeClassifier(),
+        RandomForestClassifier(),
+        AdaBoostStumpClassifier(),
+    ]
         @test reproducibility(model, X, y, loss)
     end
-    X, y = make_regression();
+    X, y = make_regression(rng=stable_rng());
     loss = LPLoss(p=2)
-    for model in [DecisionTreeRegressor(), RandomForestRegressor()]
+    for model in [
+        DecisionTreeRegressor(),
+        RandomForestRegressor(),
+    ]
         @test reproducibility(model, X, y, loss)
     end
 end