
Commit a938e80

Merge pull request #42 from JuliaAI/update-method-implementation
Implement warm restart for random forest models
2 parents a23431e + 241d184 commit a938e80
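
For context, "warm restart" here means that refitting an MLJ machine after changing only `n_trees` reuses the trees already grown, adding or dropping trees as needed, while changing any other hyperparameter forces a full recomputation. The following is a minimal usage sketch mirroring the new tests below; it assumes MLJBase (for `machine`, `fit!`, `make_blobs`) and MLJDecisionTreeInterface are loaded, and is illustrative rather than part of the commit:

    using MLJBase, MLJDecisionTreeInterface

    X, y = make_blobs()                                    # toy classification data
    mach = machine(RandomForestClassifier(n_trees=4), X, y)
    fit!(mach)                                             # grows 4 trees
    length(fitted_params(mach).forest)                     # 4

    # Only `n_trees` changed: 3 trees are added, the original 4 are kept.
    mach.model = RandomForestClassifier(n_trees=7)
    fit!(mach)                                             # logs "Adding 3 trees to the ensemble."

    # Any other change (here `max_depth`) recomputes all trees.
    mach.model = RandomForestClassifier(n_trees=7, max_depth=1)
    fit!(mach)                                             # logs the "Recomputing all 7 trees" message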

File tree

2 files changed: +159 -3 lines changed

src/MLJDecisionTreeInterface.jl

Lines changed: 117 additions & 3 deletions
@@ -129,13 +129,72 @@ function MMI.fit(
         m.min_samples_split,
         m.min_purity_increase;
         rng=m.rng)
-    cache = nothing
+    cache = deepcopy(m)
 
     report = (features=features,)
 
     return (forest, classes_seen, integers_seen), cache, report
 end
 
+info_recomputing(n) = "Detected a change to hyperparameters " *
+    "not restricted to `n_trees`. Recomputing all $n trees. "
+info_dropping(n) = "Dropping $n trees from the ensemble. "
+info_adding(n) = "Adding $n trees to the ensemble. "
+
+function MMI.update(
+    model::RandomForestClassifier,
+    verbosity::Int,
+    old_fitresult,
+    old_model,
+    Xmatrix,
+    yplain,
+    features,
+    classes,
+)
+
+    only_iterations_have_changed = MMI.is_same_except(model, old_model, :n_trees)
+
+    if !only_iterations_have_changed
+        verbosity > 0 && @info info_recomputing(model.n_trees)
+        return MMI.fit(
+            model,
+            verbosity,
+            Xmatrix,
+            yplain,
+            features,
+            classes,
+        )
+    end
+
+    old_forest = old_fitresult[1]
+    Δn_trees = model.n_trees - old_model.n_trees
+    # if `n_trees` drops, then truncate; otherwise compute more trees
+    if Δn_trees < 0
+        verbosity > 0 && @info info_dropping(-Δn_trees)
+        forest = old_forest[1:model.n_trees]
+    else
+        verbosity > 0 && @info info_adding(Δn_trees)
+        forest = DT.build_forest(
+            old_forest,
+            yplain, Xmatrix,
+            model.n_subfeatures,
+            model.n_trees,
+            model.sampling_fraction,
+            model.max_depth,
+            model.min_samples_leaf,
+            model.min_samples_split,
+            model.min_purity_increase;
+            rng=model.rng,
+        )
+    end
+
+    fitresult = (forest, old_fitresult[2:3]...)
+    cache = deepcopy(model)
+    report = (features=features,)
+    return fitresult, cache, report
+
+end
+
 MMI.fitted_params(::RandomForestClassifier, (forest,_)) = (forest=forest,)
 
 function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
@@ -145,7 +204,7 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
 end
 
 MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true
-
+MMI.iteration_parameter(::Type{<:RandomForestClassifier}) = :n_trees
 
 # # ADA BOOST STUMP CLASSIFIER
 
@@ -266,17 +325,72 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, Xmatrix, y, features)
         rng=m.rng
     )
 
-    cache = nothing
+    cache = deepcopy(m)
+    report = (features=features,)
+
+    return forest, cache, report
+end
+
+function MMI.update(
+    model::RandomForestRegressor,
+    verbosity::Int,
+    old_forest,
+    old_model,
+    Xmatrix,
+    y,
+    features,
+)
+
+    only_iterations_have_changed = MMI.is_same_except(model, old_model, :n_trees)
+
+    if !only_iterations_have_changed
+        verbosity > 0 && @info info_recomputing(model.n_trees)
+        return MMI.fit(
+            model,
+            verbosity,
+            Xmatrix,
+            y,
+            features,
+        )
+    end
+
+    Δn_trees = model.n_trees - old_model.n_trees
+
+    # if `n_trees` drops, then truncate; otherwise compute more trees
+    if Δn_trees < 0
+        verbosity > 0 && @info info_dropping(-Δn_trees)
+        forest = old_forest[1:model.n_trees]
+    else
+        verbosity > 0 && @info info_adding(Δn_trees)
+        forest = DT.build_forest(
+            old_forest,
+            y,
+            Xmatrix,
+            model.n_subfeatures,
+            model.n_trees,
+            model.sampling_fraction,
+            model.max_depth,
+            model.min_samples_leaf,
+            model.min_samples_split,
+            model.min_purity_increase;
+            rng=model.rng
+        )
+    end
+
+    cache = deepcopy(model)
     report = (features=features,)
 
     return forest, cache, report
+
 end
 
 MMI.fitted_params(::RandomForestRegressor, forest) = (forest=forest,)
 
 MMI.predict(::RandomForestRegressor, forest, Xnew) = DT.apply_forest(forest, Xnew)
 
 MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true
+MMI.iteration_parameter(::Type{<:RandomForestRegressor}) = :n_trees
+
 
 # # ALIASES FOR TYPE UNIONS
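
Editorial note, not part of the diff: the cache returned by `fit` is now a deep copy of the model, and MLJ hands that cache back to `update` on a refit (bound to `old_model` above), so old and new hyperparameters can be compared. The comparison is delegated to `MMI.is_same_except`; a minimal sketch of how that decision plays out, assuming only the behaviour implied by the code above:

    import MLJModelInterface as MMI
    using MLJDecisionTreeInterface

    old = RandomForestClassifier(n_trees=4)

    more_trees = RandomForestClassifier(n_trees=9)         # only the iteration count differs
    MMI.is_same_except(more_trees, old, :n_trees)          # true  => trees are added (or dropped)

    different = RandomForestClassifier(n_trees=9, max_depth=1)
    MMI.is_same_except(different, old, :n_trees)           # false => falls back to a full MMI.fit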

test/runtests.jl

Lines changed: 42 additions & 0 deletions
@@ -271,4 +271,46 @@ end
 end
 end
 
+@testset "warm restart" begin
+    for M in [:RandomForestClassifier, :RandomForestRegressor]
+        data = (M == :RandomForestClassifier ? make_blobs() : make_regression())
+        quote
+            model = $M(n_trees=4, rng = stable_rng()) # model with 4 trees
+            @test MLJBase.iteration_parameter(model) === :n_trees
+            mach = machine(model, $data...)
+            fit!(mach, verbosity=0)
+            forest1_4 = fitted_params(mach).forest
+            @test length(forest1_4) == 4
+
+            # increase n_trees:
+            mach.model = $M(n_trees=7, rng = stable_rng())
+            @test_logs(
+                (:info, r""),
+                (:info, MLJDecisionTreeInterface.info_adding(3)),
+                fit!(mach, verbosity=1),
+            )
+
+            # decrease n_trees:
+            mach.model = $M(n_trees=5, rng = stable_rng())
+            @test_logs(
+                (:info, r""),
+                (:info, MLJDecisionTreeInterface.info_dropping(2)),
+                fit!(mach, verbosity=1),
+            )
+            forest1_5 = fitted_params(mach).forest
+            @test length(forest1_5) == 5
+
+            # change a different hyperparameter:
+            mach.model = $M(n_trees=5, rng = stable_rng(), max_depth=1)
+            @test_logs(
+                (:info, r""),
+                (:info, MLJDecisionTreeInterface.info_recomputing(5)),
+                fit!(mach, verbosity=1),
+            )
+            forest1_5_again = fitted_params(mach).forest
+            @test length(forest1_5_again) == 5
+        end |> eval
+    end
+end
+
 true
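
To exercise the new testset locally, the standard package test entry point applies (generic Julia tooling, not specific to this commit):

    using Pkg
    Pkg.test("MLJDecisionTreeInterface")   # runs test/runtests.jl, including the "warm restart" testset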
