@@ -129,13 +129,72 @@ function MMI.fit(
         m.min_samples_split,
         m.min_purity_increase;
         rng=m.rng)
-    cache = nothing
+    cache = deepcopy(m)
 
     report = (features=features,)
 
     return (forest, classes_seen, integers_seen), cache, report
 end
 
+info_recomputing(n) = "Detected a change to hyperparameters "*
+    "not restricted to `n_trees`. Recomputing all $n trees."
+info_dropping(n) = "Dropping $n trees from the ensemble."
+info_adding(n) = "Adding $n trees to the ensemble."
+
+function MMI.update(
+    model::RandomForestClassifier,
+    verbosity::Int,
+    old_fitresult,
+    old_model,
+    Xmatrix,
+    yplain,
+    features,
+    classes,
+)
+
+    only_iterations_have_changed = MMI.is_same_except(model, old_model, :n_trees)
+
+    if !only_iterations_have_changed
+        verbosity > 0 && @info info_recomputing(model.n_trees)
+        return MMI.fit(
+            model,
+            verbosity,
+            Xmatrix,
+            yplain,
+            features,
+            classes,
+        )
+    end
+
+    old_forest = old_fitresult[1]
+    Δn_trees = model.n_trees - old_model.n_trees
+    # if `n_trees` drops, then truncate; otherwise, compute more trees
+    if Δn_trees < 0
+        verbosity > 0 && @info info_dropping(-Δn_trees)
+        forest = old_forest[1:model.n_trees]
+    else
+        verbosity > 0 && @info info_adding(Δn_trees)
+        forest = DT.build_forest(
+            old_forest,
+            yplain, Xmatrix,
+            model.n_subfeatures,
+            model.n_trees,
+            model.sampling_fraction,
+            model.max_depth,
+            model.min_samples_leaf,
+            model.min_samples_split,
+            model.min_purity_increase;
+            rng=model.rng,
+        )
+    end
+
+    fitresult = (forest, old_fitresult[2:3]...)
+    cache = deepcopy(model)
+    report = (features=features,)
+    return fitresult, cache, report
+
+end
+
 MMI.fitted_params(::RandomForestClassifier, (forest,_)) = (forest=forest,)
 
 function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
@@ -145,7 +204,7 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
 end
 
 MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true
-
+MMI.iteration_parameter(::Type{<:RandomForestClassifier}) = :n_trees
 
 ## ADA BOOST STUMP CLASSIFIER
 
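Aside (not part of the diff): the `update` method and `iteration_parameter` declaration above let MLJ warm-restart a fitted machine when only `n_trees` changes, rather than retraining from scratch. A minimal usage sketch, assuming MLJ is installed and using its standard loading macros; the data and hyperparameter values are illustrative:

using MLJ
RandomForest = @load RandomForestClassifier pkg=DecisionTree
X, y = @load_iris              # any tabular classification data works here
mach = machine(RandomForest(n_trees=100), X, y)
fit!(mach)                     # builds 100 trees from scratch
mach.model.n_trees = 150
fit!(mach)                     # `update` kicks in: only 50 new trees are grown
mach.model.n_trees = 120
fit!(mach)                     # ensemble is truncated to its first 120 trees
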
@@ -266,17 +325,72 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, Xmatrix, y, features)
         rng=m.rng
     )
 
-    cache = nothing
+    cache = deepcopy(m)
+    report = (features=features,)
+
+    return forest, cache, report
+end
+
+function MMI.update(
+    model::RandomForestRegressor,
+    verbosity::Int,
+    old_forest,
+    old_model,
+    Xmatrix,
+    y,
+    features,
+)
+
+    only_iterations_have_changed = MMI.is_same_except(model, old_model, :n_trees)
+
+    if !only_iterations_have_changed
+        verbosity > 0 && @info info_recomputing(model.n_trees)
+        return MMI.fit(
+            model,
+            verbosity,
+            Xmatrix,
+            y,
+            features,
+        )
+    end
+
+    Δn_trees = model.n_trees - old_model.n_trees
+
+    # if `n_trees` drops, then truncate; otherwise, compute more trees
+    if Δn_trees < 0
+        verbosity > 0 && @info info_dropping(-Δn_trees)
+        forest = old_forest[1:model.n_trees]
+    else
+        verbosity > 0 && @info info_adding(Δn_trees)
+        forest = DT.build_forest(
+            old_forest,
+            y,
+            Xmatrix,
+            model.n_subfeatures,
+            model.n_trees,
+            model.sampling_fraction,
+            model.max_depth,
+            model.min_samples_leaf,
+            model.min_samples_split,
+            model.min_purity_increase;
+            rng=model.rng
+        )
+    end
+
+    cache = deepcopy(model)
     report = (features=features,)
 
     return forest, cache, report
+
 end
 
 MMI.fitted_params(::RandomForestRegressor, forest) = (forest=forest,)
 
 MMI.predict(::RandomForestRegressor, forest, Xnew) = DT.apply_forest(forest, Xnew)
 
 MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true
+MMI.iteration_parameter(::Type{<:RandomForestRegressor}) = :n_trees
+
 
 ## ALIASES FOR TYPE UNIONS
 
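Aside (not part of the diff): both `update` methods hinge on `MMI.is_same_except`, which returns `true` exactly when two models agree on every hyperparameter except the listed ones. An illustrative REPL session with hypothetical hyperparameter values:

julia> import MLJModelInterface as MMI

julia> m1 = RandomForestRegressor(n_trees=100, max_depth=3);

julia> m2 = RandomForestRegressor(n_trees=500, max_depth=3);

julia> MMI.is_same_except(m1, m2, :n_trees)   # only the iteration count differs
true

julia> m3 = RandomForestRegressor(n_trees=500, max_depth=5);

julia> MMI.is_same_except(m1, m3, :n_trees)   # `max_depth` differs too
false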