@@ -129,13 +129,66 @@ function MMI.fit(
         m.min_samples_split,
         m.min_purity_increase;
         rng=m.rng)
-    cache = nothing
+    cache = deepcopy(m)
 
     report = (features=features,)
 
     return (forest, classes_seen, integers_seen), cache, report
 end
 
+function MMI.update(
+    model::RandomForestClassifier,
+    verbosity::Int,
+    old_fitresult,
+    old_model,
+    Xmatrix,
+    yplain,
+    features,
+    classes,
+)
+
+    only_iterations_have_changed = MMI.is_same_except(model, old_model, :n_trees)
+
+    if !only_iterations_have_changed
+        return MMI.fit(
+            model,
+            verbosity,
+            Xmatrix,
+            yplain,
+            features,
+            classes,
+        )
+    end
+
+    old_forest = old_fitresult[1]
+    Δn_trees = model.n_trees - old_model.n_trees
+    # if `n_trees` drops, then truncate; otherwise compute more trees
+    if Δn_trees < 0
+        verbosity > 0 && @info "Dropping $(-Δn_trees) trees from the forest."
+        forest = old_forest[1:model.n_trees]
+    else
+        verbosity > 0 && @info "Adding $Δn_trees trees to the forest."
+        forest = DT.build_forest(
+            old_forest,
+            yplain, Xmatrix,
+            model.n_subfeatures,
+            model.n_trees,
+            model.sampling_fraction,
+            model.max_depth,
+            model.min_samples_leaf,
+            model.min_samples_split,
+            model.min_purity_increase;
+            rng=model.rng,
+        )
+    end
+
+    fitresult = (forest, old_fitresult[2:3]...)
+    cache = deepcopy(model)
+    report = (features=features,)
+    return fitresult, cache, report
+
+end
+
 MMI.fitted_params(::RandomForestClassifier, (forest,_)) = (forest=forest,)
 
 function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
@@ -145,7 +198,7 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
 end
 
 MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true
-
+MMI.iteration_parameter(::Type{<:RandomForestClassifier}) = :n_trees
 
 # # ADA BOOST STUMP CLASSIFIER
 
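With `MMI.update` and `iteration_parameter` defined, an MLJ machine wrapping this model can warm-restart whenever only `n_trees` changes, instead of rebuilding the forest from scratch. A minimal usage sketch, assuming MLJ and this interface are loaded and `X`, `y` are a feature table and target of appropriate scitype:

    using MLJ
    Forest = @load RandomForestClassifier pkg=DecisionTree
    forest = Forest(n_trees=10)
    mach = machine(forest, X, y)
    fit!(mach)            # cold fit: builds 10 trees via `MMI.fit`
    forest.n_trees = 20
    fit!(mach)            # only `n_trees` changed, so `MMI.update` grows the forest by 10 trees
    forest.n_trees = 5
    fit!(mach)            # `MMI.update` truncates the forest to its first 5 trees
    forest.max_depth = 3
    fit!(mach)            # any other hyperparameter change falls back to a cold `MMI.fit`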
@@ -266,17 +319,71 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, Xmatrix, y, features)
         rng=m.rng
     )
 
-    cache = nothing
+    cache = deepcopy(m)
+    report = (features=features,)
+
+    return forest, cache, report
+end
+
+function MMI.update(
+    model::RandomForestRegressor,
+    verbosity::Int,
+    old_forest,
+    old_model,
+    Xmatrix,
+    y,
+    features,
+)
+
+    only_iterations_have_changed = MMI.is_same_except(model, old_model, :n_trees)
+
+    if !only_iterations_have_changed
+        return MMI.fit(
+            model,
+            verbosity,
+            Xmatrix,
+            y,
+            features,
+        )
+    end
+
+    Δn_trees = model.n_trees - old_model.n_trees
+
+    # if `n_trees` drops, then truncate; otherwise compute more trees
+    if Δn_trees < 0
+        verbosity > 0 && @info "Dropping $(-Δn_trees) trees from the forest."
+        forest = old_forest[1:model.n_trees]
+    else
+        verbosity > 0 && @info "Adding $Δn_trees trees to the forest."
+        forest = DT.build_forest(
+            old_forest,
+            y,
+            Xmatrix,
+            model.n_subfeatures,
+            model.n_trees,
+            model.sampling_fraction,
+            model.max_depth,
+            model.min_samples_leaf,
+            model.min_samples_split,
+            model.min_purity_increase;
+            rng=model.rng
+        )
+    end
+
+    cache = deepcopy(model)
     report = (features=features,)
 
     return forest, cache, report
+
 end
 
 MMI.fitted_params(::RandomForestRegressor, forest) = (forest=forest,)
 
 MMI.predict(::RandomForestRegressor, forest, Xnew) = DT.apply_forest(forest, Xnew)
 
 MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true
+MMI.iteration_parameter(::Type{<:RandomForestRegressor}) = :n_trees
+
 
 
 # # ALIASES FOR TYPE UNIONS
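Because `:n_trees` is also declared as the iteration parameter for the regressor, MLJ's iteration utilities can grow the forest incrementally and stop on an out-of-sample criterion. A sketch of the idea using `IteratedModel`, with hypothetical data `X`, `y` and arbitrarily chosen resampling, measure and controls:

    using MLJ
    Forest = @load RandomForestRegressor pkg=DecisionTree
    iterated_forest = IteratedModel(
        model=Forest(),
        resampling=Holdout(fraction_train=0.7),
        measure=rms,
        controls=[Step(10), Patience(3), NumberLimit(20)],
    )
    mach = machine(iterated_forest, X, y)
    fit!(mach)   # each `Step(10)` bumps `n_trees` by 10; `MMI.update` reuses the existing trees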