Skip to content

Commit 8323234

Browse files
committed
implement warm restart for random forest models (x 2)
1 parent 39cdbbe commit 8323234

File tree

2 files changed

+138
-3
lines changed

2 files changed

+138
-3
lines changed

src/MLJDecisionTreeInterface.jl

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,66 @@ function MMI.fit(
129129
m.min_samples_split,
130130
m.min_purity_increase;
131131
rng=m.rng)
132-
cache = nothing
132+
cache = deepcopy(m)
133133

134134
report = (features=features,)
135135

136136
return (forest, classes_seen, integers_seen), cache, report
137137
end
138138

139+
function MMI.update(
    model::RandomForestClassifier,
    verbosity::Int,
    old_fitresult,
    old_model,
    Xmatrix,
    yplain,
    features,
    classes,
)
    # Warm restarts are only valid when `n_trees` is the sole hyperparameter
    # that changed since the last call; anything else forces a cold retrain.
    MMI.is_same_except(model, old_model, :n_trees) || return MMI.fit(
        model,
        verbosity,
        Xmatrix,
        yplain,
        features,
        classes,
    )

    old_forest = old_fitresult[1]
    Δn_trees = model.n_trees - old_model.n_trees

    # If `n_trees` grew (or is unchanged), grow extra trees on top of the old
    # forest; if it shrank, truncate the forest instead.
    if Δn_trees >= 0
        verbosity > 0 && @info "Adding $Δn_trees trees to the forest. "
        forest = DT.build_forest(
            old_forest,
            yplain, Xmatrix,
            model.n_subfeatures,
            model.n_trees,
            model.sampling_fraction,
            model.max_depth,
            model.min_samples_leaf,
            model.min_samples_split,
            model.min_purity_increase;
            rng=model.rng,
        )
    else
        verbosity > 0 && @info "Dropping $(-Δn_trees) trees from the forest. "
        forest = old_forest[1:model.n_trees]
    end

    # Keep the class bookkeeping (elements 2:3 of the old fit result) as-is;
    # only the forest itself is replaced.
    fitresult = (forest, old_fitresult[2:3]...)
    # The cache is a snapshot of the model used to detect what changed on the
    # next `update` call.
    cache = deepcopy(model)
    report = (features=features,)
    return fitresult, cache, report
end
191+
139192
# For the classifier, the learned forest is the first component of the
# composite fit result.
function MMI.fitted_params(::RandomForestClassifier, fitresult)
    return (forest=first(fitresult),)
end
140193

141194
function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
@@ -145,7 +198,7 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
145198
end
146199

147200
MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true
148-
201+
# Declare `n_trees` as the iteration parameter, enabling iteration-aware
# tooling (and the warm-restart `update` path) to vary it.
function MMI.iteration_parameter(::Type{<:RandomForestClassifier})
    return :n_trees
end
149202

150203
# # ADA BOOST STUMP CLASSIFIER
151204

@@ -266,17 +319,71 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, Xmatrix, y, features)
266319
rng=m.rng
267320
)
268321

269-
cache = nothing
322+
cache = deepcopy(m)
323+
report = (features=features,)
324+
325+
return forest, cache, report
326+
end
327+
328+
function MMI.update(
    model::RandomForestRegressor,
    verbosity::Int,
    old_forest,
    old_model,
    Xmatrix,
    y,
    features,
)
    # Warm restarts are only valid when `n_trees` is the sole hyperparameter
    # that changed since the last call; anything else forces a cold retrain.
    MMI.is_same_except(model, old_model, :n_trees) || return MMI.fit(
        model,
        verbosity,
        Xmatrix,
        y,
        features,
    )

    Δn_trees = model.n_trees - old_model.n_trees

    # If `n_trees` grew (or is unchanged), grow extra trees on top of the old
    # forest; if it shrank, truncate the forest instead.
    if Δn_trees >= 0
        verbosity > 0 && @info "Adding $Δn_trees trees to the forest. "
        forest = DT.build_forest(
            old_forest,
            y,
            Xmatrix,
            model.n_subfeatures,
            model.n_trees,
            model.sampling_fraction,
            model.max_depth,
            model.min_samples_leaf,
            model.min_samples_split,
            model.min_purity_increase;
            rng=model.rng
        )
    else
        verbosity > 0 && @info "Dropping $(-Δn_trees) trees from the forest. "
        forest = old_forest[1:model.n_trees]
    end

    # The cache is a snapshot of the model used to detect what changed on the
    # next `update` call.
    cache = deepcopy(model)
    report = (features=features,)
    return forest, cache, report
end
274379

275380
# For the regressor, the fit result is the forest itself.
function MMI.fitted_params(::RandomForestRegressor, forest)
    return (forest=forest,)
end
276381

277382
# Delegate prediction to DecisionTree.jl's `apply_forest`.
function MMI.predict(::RandomForestRegressor, forest, Xnew)
    return DT.apply_forest(forest, Xnew)
end
278383

279384
MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true
385+
# Declare `n_trees` as the iteration parameter, enabling iteration-aware
# tooling (and the warm-restart `update` path) to vary it.
function MMI.iteration_parameter(::Type{<:RandomForestRegressor})
    return :n_trees
end
386+
280387

281388
# # ALIASES FOR TYPE UNIONS
282389

test/runtests.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,4 +271,32 @@ end
271271
end
272272
end
273273

274+
@testset "warm restart" begin
    # Loop over the model *types* directly — no need for the
    # `quote … end |> eval` metaprogramming, which obscures the test and
    # defeats tooling; `@test_logs` and friends accept runtime values fine.
    for M in [RandomForestClassifier, RandomForestRegressor]
        data = M == RandomForestClassifier ? make_blobs() : make_regression()

        model = M(n_trees=4, rng=stable_rng()) # model with 4 trees
        @test MLJBase.iteration_parameter(model) === :n_trees
        mach = machine(model, data...)
        fit!(mach, verbosity=0)
        forest1_4 = fitted_params(mach).forest
        @test length(forest1_4) == 4

        # Raising `n_trees` 4 → 7 should warm-restart and log tree addition.
        mach.model = M(n_trees=7, rng=stable_rng())
        @test_logs(
            (:info, r""),
            (:info, r"Adding 3 trees"),
            fit!(mach, verbosity=1),
        )

        # Lowering `n_trees` 7 → 5 should truncate and log tree removal.
        mach.model = M(n_trees=5, rng=stable_rng())
        @test_logs(
            (:info, r""),
            (:info, r"Dropping 2 trees"),
            fit!(mach, verbosity=1),
        )
        forest1_5 = fitted_params(mach).forest
        @test length(forest1_5) == 5
    end
end
301+
274302
true

0 commit comments

Comments
 (0)