@@ -202,34 +202,35 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
202
202
)
203
203
204
204
# Compute required number of features to select
205
- n_features = selector. n_features # Remember to modify this estimate later
205
+ n_features_select = selector. n_features
206
206
# # zero indicates that half of the features be selected.
207
- if n_features == 0
208
- n_features = div (nfeatures, 2 )
209
- elseif 0 < n_features < 1
210
- n_features = round (Int, n_features * n_features )
207
+ if n_features_select == 0
208
+ n_features_select = div (nfeatures, 2 )
209
+ elseif 0 < n_features_select < 1
210
+ n_features_select = round (Int, n_features_select * nfeatures )
211
211
else
212
- n_features = round (Int, n_features )
212
+ n_features_select = round (Int, n_features_select )
213
213
end
214
214
215
215
step = selector. step
216
216
217
217
if 0 < step < 1
218
- step = round (Int, max (1 , step * n_features ))
218
+ step = round (Int, max (1 , step * n_features_select ))
219
219
else
220
220
step = round (Int, step)
221
221
end
222
222
223
223
support = trues (nfeatures)
224
- ranking = ones (nfeatures) # every feature has equal rank initially
225
- indexes = axes (support, 1 )
224
+ ranking = ones (Int, nfeatures) # every feature has equal rank initially
225
+ mask = trues (nfeatures) # for boolean indexing of ranking vector in while loop below
226
226
227
227
# Elimination
228
- features_left = copy (features)
229
- while sum (support) > n_features
228
+ features_left = features
229
+ n_features_left = length (features_left)
230
+ while n_features_left > n_features_select
230
231
# Rank the remaining features
231
232
model = selector. model
232
- verbosity > 0 && @info (" Fitting estimator with $(sum (support) ) features." )
233
+ verbosity > 0 && @info (" Fitting estimator with $(n_features_left ) features." )
233
234
234
235
data = MMI. reformat (model, MMI. selectcols (X, features_left), args... )
235
236
@@ -249,24 +250,25 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
249
250
ranks = sortperm (importances)
250
251
251
252
# Eliminate the worse features
252
- threshold = min (step, sum (support) - n_features )
253
-
254
- support[indexes[ranks][ 1 : threshold]] . = false
255
- ranking[. ! support] .+ = 1
253
+ threshold = min (step, n_features_left - n_features_select )
254
+ @views (support[support][ranks[ 1 : threshold]]) . = false
255
+ mask . = support .= = false
256
+ @views ( ranking[mask]) .+ = 1
256
257
257
258
# Remaining features
258
- features_left = @view (features[support])
259
+ features_left = features[support]
260
+ n_features_left = length (features_left)
259
261
end
260
262
261
263
# Set final attributes
262
264
data = MMI. reformat (selector. model, MMI. selectcols (X, features_left), args... )
263
- verbosity > 0 && @info (" Fitting estimator with $(sum (support) ) features." )
265
+ verbosity > 0 && @info (" Fitting estimator with $(n_features_left ) features." )
264
266
model_fitresult, _, model_report = MMI. fit (selector. model, verbosity - 1 , data... )
265
267
266
268
fitresult = (
267
269
support = support,
268
270
model_fitresult = model_fitresult,
269
- features_left = copy ( features_left) ,
271
+ features_left = features_left,
270
272
features = features
271
273
)
272
274
report = (
280
282
281
283
function MMI. fitted_params (model:: RFE , fitresult)
282
284
(
283
- features_left = fitresult. features_left,
285
+ features_left = copy ( fitresult. features_left) ,
284
286
model_fitresult = MMI. fitted_params (model. model, fitresult. model_fitresult)
285
287
)
286
288
end
@@ -295,15 +297,45 @@ function MMI.transform(::RFE, fitresult, X)
295
297
sch = Tables. schema (Tables. columns (X))
296
298
if (length (fitresult. features) == length (sch. names) &&
297
299
! all (e -> e in sch. names, fitresult. features))
298
- throw (
299
- ERR_FEATURES_SEEN
300
- )
300
+ throw (
301
+ ERR_FEATURES_SEEN
302
+ )
301
303
end
302
304
return MMI. selectcols (X, fitresult. features_left)
303
305
end
304
306
305
307
function MMI. feature_importances (:: RFE , fitresult, report)
306
- return Pair .(fitresult. features, report. ranking)
308
+ return Pair .(fitresult. features, Iterators. reverse (report. ranking))
309
+ end
310
+
311
+ function MMI. save (model:: RFE , fitresult)
312
+ support = fitresult. support
313
+ atomic_fitresult = fitresult. model_fitresult
314
+ features_left = fitresult. features_left
315
+ features = fitresult. features
316
+
317
+ atom = model. model
318
+ return (
319
+ support = copy (support),
320
+ model_fitresult = MMI. save (atom, atomic_fitresult),
321
+ features_left = copy (features_left),
322
+ features = copy (features)
323
+ )
324
+ end
325
+
326
+ function MMI. restore (model:: RFE , serializable_fitresult)
327
+ support = serializable_fitresult. support
328
+ atomic_serializable_fitresult = serializable_fitresult. model_fitresult
329
+ features_left = serializable_fitresult. features_left
330
+ features = serializable_fitresult. features
331
+
332
+ atom = model. model
333
+ return (
334
+ support = support,
335
+ model_fitresult = MMI. restore (atom, atomic_serializable_fitresult),
336
+ features_left = features_left,
337
+ features = features
338
+ )
307
339
end
308
340
309
341
# # Traits definitions
0 commit comments