@@ -23,7 +23,7 @@ for (ModelType, ModelSuperType) in MODELTYPE_GIVEN_SUPERTYPES
    ex = quote
        mutable struct $ModelType{M<:Supervised} <: $ModelSuperType
            model::M
-            n_features_to_select::Float64
+            n_features::Float64
            step::Float64
        end
    end
@@ -34,7 +34,7 @@ eval(:(const RFE{M} = Union{$((Expr(:curly, modeltype, :M) for modeltype in MODE

# Common keyword constructor for both model types
"""
-    RecursiveFeatureElimination(model, n_features_to_select, step)
+    RecursiveFeatureElimination(model, n_features, step)

This model implements a recursive feature elimination algorithm for feature selection.
It recursively removes features, training a base model on the remaining features and
@@ -73,7 +73,7 @@ Train the machine using `fit!(mach, rows=...)`.
- model: A base model with a `fit` method that provides information on feature
  importance (i.e. `reports_feature_importances(model) == true`)

- - n_features_to_select::Real = 0: The number of features to select. If `0`, half of the
+ - n_features::Real = 0: The number of features to select. If `0`, half of the
  features are selected. If a positive integer, the parameter is the absolute number
  of features to select. If a real number between 0 and 1, it is the fraction of features
  to select.
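
The three cases above resolve to a single integer at fit time, via the branch in `MMI.fit` further down in this diff. A minimal sketch of the rule, assuming a hypothetical helper name `resolve_n_features` and writing `nfeatures` for the column count of the feature matrix:

    # Hypothetical helper mirroring the if/elseif/else branch in `MMI.fit` below.
    function resolve_n_features(n_features::Real, nfeatures::Int)
        if n_features == 0
            return div(nfeatures, 2)                   # default: half the features
        elseif 0 < n_features < 1
            return round(Int, nfeatures * n_features)  # a fraction of the features
        else
            return round(Int, n_features)              # an absolute count
        end
    end

    resolve_n_features(0, 10)    # 5
    resolve_n_features(0.3, 10)  # 3
    resolve_n_features(4, 10)    # 4
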
@@ -136,7 +136,7 @@ predict(mach, Xnew)
function RecursiveFeatureElimination(
    args...;
    model=nothing,
-    n_features_to_select::Real=0,
+    n_features::Real=0,
    step::Real=1
)
    # user can specify model as argument instead of kwarg:
@@ -155,11 +155,11 @@ function RecursiveFeatureElimination(
    MMI.reports_feature_importances(model) || throw(ERR_FEATURE_IMPORTANCE_SUPPORT)
    if model isa Deterministic
        selector = DeterministicRecursiveFeatureElimination{typeof(model)}(
-            model, Float64(n_features_to_select), Float64(step)
+            model, Float64(n_features), Float64(step)
        )
    elseif model isa Probabilistic
        selector = ProbabilisticRecursiveFeatureElimination{typeof(model)}(
-            model, Float64(n_features_to_select), Float64(step)
+            model, Float64(n_features), Float64(step)
        )
    else
        throw(ERR_MODEL_TYPE)
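
After the rename, the constructor is called with the new keyword as sketched below; `SomeRegressor` is a placeholder for any `Deterministic` or `Probabilistic` model with `reports_feature_importances(model) == true`:

    # Hypothetical usage; `SomeRegressor` stands in for a real base model.
    base = SomeRegressor()
    rfe = RecursiveFeatureElimination(base, n_features=0.5, step=1)
    # the model may equally be passed via the `model` keyword:
    rfe = RecursiveFeatureElimination(model=base, n_features=0.5, step=1)
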
@@ -176,9 +176,9 @@ function MMI.clean!(selector::RFE)
        "Resetting `step = 1`"
    end

-    if selector.n_features_to_select < 0
-        msg *= "specified `n_features_to_select` must be non-negative.\n" *
-        "Resetting `n_features_to_select = 0`"
+    if selector.n_features < 0
+        msg *= "specified `n_features` must be non-negative.\n" *
+        "Resetting `n_features = 0`"
    end

    return msg
@@ -192,14 +192,14 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
    nfeatures < 2 && throw(ArgumentError("The number of features in the feature matrix must be at least 2."))

    # Compute required number of features to select
-    n_features_to_select = selector.n_features_to_select # Remember to modify this estimate later
+    n_features = selector.n_features # Remember to modify this estimate later
    # zero indicates that half of the features be selected.
-    if n_features_to_select == 0
-        n_features_to_select = div(nfeatures, 2)
-    elseif 0 < n_features_to_select < 1
-        n_features_to_select = round(Int, nfeatures * n_features_to_select)
+    if n_features == 0
+        n_features = div(nfeatures, 2)
+    elseif 0 < n_features < 1
+        n_features = round(Int, nfeatures * n_features)
    else
-        n_features_to_select = round(Int, n_features_to_select)
+        n_features = round(Int, n_features)
    end

    step = selector.step
@@ -216,7 +216,7 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)

    # Elimination
    features_left = copy(features)
-    while sum(support) > n_features_to_select
+    while sum(support) > n_features
        # Rank the remaining features
        model = selector.model
        verbosity > 0 && @info("Fitting estimator with $(sum(support)) features.")
@@ -239,7 +239,7 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
        ranks = sortperm(importances)

        # Eliminate the worst features
-        threshold = min(step, sum(support) - n_features_to_select)
+        threshold = min(step, sum(support) - n_features)

        support[indexes[ranks][1:threshold]] .= false
        ranking[.!support] .+= 1
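
Clamping `threshold` with `min` ensures the final pass never eliminates past the requested `n_features`. A standalone sketch of the resulting schedule, under a hypothetical helper name and integer inputs:

    # Drop up to `step` features per pass, clamped so the survivor count
    # never falls below `n_features`.
    function elimination_schedule(nfeatures::Int, n_features::Int, step::Int)
        remaining = nfeatures
        dropped = Int[]
        while remaining > n_features
            threshold = min(step, remaining - n_features)  # clamp the last pass
            push!(dropped, threshold)
            remaining -= threshold
        end
        return dropped
    end

    elimination_schedule(10, 4, 4)  # [4, 2]: the second pass is clamped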