Commit 1d312c2

rename n_features_to_select to n_features

1 parent: d2e41a5 · commit: 1d312c2

File tree

1 file changed: +16 additions, −16 deletions

src/models/rfe.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ for (ModelType, ModelSuperType) in MODELTYPE_GIVEN_SUPERTYPES
2323
ex = quote
2424
mutable struct $ModelType{M<:Supervised} <: $ModelSuperType
2525
model::M
26-
n_features_to_select::Float64
26+
n_features::Float64
2727
step::Float64
2828
end
2929
end
@@ -34,7 +34,7 @@ eval(:(const RFE{M} = Union{$((Expr(:curly, modeltype, :M) for modeltype in MODE
3434

3535
# Common keyword constructor for both model types
3636
"""
37-
RecursiveFeatureElimination(model, n_features_to_select, step)
37+
RecursiveFeatureElimination(model, n_features, step)
3838
3939
This model implements a recursive feature elimination algorithm for feature selection.
4040
It recursively removes features, training a base model on the remaining features and
@@ -73,7 +73,7 @@ Train the machine using `fit!(mach, rows=...)`.
7373
- model: A base model with a `fit` method that provides information on feature
7474
feature importance (i.e `reports_feature_importances(model) == true`)
7575
76-
- n_features_to_select::Real = 0: The number of features to select. If `0`, half of the
76+
- n_features::Real = 0: The number of features to select. If `0`, half of the
7777
features are selected. If a positive integer, the parameter is the absolute number
7878
of features to select. If a real number between 0 and 1, it is the fraction of features
7979
to select.
@@ -136,7 +136,7 @@ predict(mach, Xnew)
136136
function RecursiveFeatureElimination(
137137
args...;
138138
model=nothing,
139-
n_features_to_select::Real=0,
139+
n_features::Real=0,
140140
step::Real = 1
141141
)
142142
# user can specify model as argument instead of kwarg:
@@ -155,11 +155,11 @@ function RecursiveFeatureElimination(
155155
MMI.reports_feature_importances(model) || throw(ERR_FEATURE_IMPORTANCE_SUPPORT)
156156
if model isa Deterministic
157157
selector = DeterministicRecursiveFeatureElimination{typeof(model)}(
158-
model, Float64(n_features_to_select), Float64(step)
158+
model, Float64(n_features), Float64(step)
159159
)
160160
elseif model isa Probabilistic
161161
selector = ProbabilisticRecursiveFeatureElimination{typeof(model)}(
162-
model, Float64(n_features_to_select), Float64(step)
162+
model, Float64(n_features), Float64(step)
163163
)
164164
else
165165
throw(ERR_MODEL_TYPE)
@@ -176,9 +176,9 @@ function MMI.clean!(selector::RFE)
176176
"Resetting `step = 1`"
177177
end
178178

179-
if selector.n_features_to_select < 0
179+
if selector.n_features < 0
180180
msg *= "specified `step` must be non-negative.\n"*
181-
"Resetting `n_features_to_select = 0`"
181+
"Resetting `n_features = 0`"

[NOTE(review): the warning text on this branch reads "specified `step` must be
non-negative", but the condition checks `selector.n_features` — this looks like a
pre-existing copy-paste error in the message, not something introduced by the rename;
confirm against the surrounding `MMI.clean!` body.]
182182
end
183183

184184
return msg
@@ -192,14 +192,14 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
192192
nfeatures < 2 && throw(ArgumentError("The number of features in the feature matrix must be at least 2."))
193193

194194
# Compute required number of features to select
195-
n_features_to_select = selector.n_features_to_select # Remember to modify this estimate later
195+
n_features = selector.n_features # Remember to modify this estimate later
196196
## zero indicates that half of the features be selected.
197-
if n_features_to_select == 0
198-
n_features_to_select = div(nfeatures, 2)
199-
elseif 0 < n_features_to_select < 1
200-
n_features_to_select = round(Int, n_features * n_features_to_select)
197+
if n_features == 0
198+
n_features = div(nfeatures, 2)
199+
elseif 0 < n_features < 1
200+
n_features = round(Int, n_features * n_features)

[NOTE(review): this line is buggy after the rename — when `0 < n_features < 1` the
fraction should be scaled by the total feature count, i.e.
`round(Int, nfeatures * n_features)`, but the mechanical rename left it multiplying the
fraction by itself. The pre-rename line referenced an undefined `n_features` as the
multiplier, so the defect predates this commit; either way the intended multiplicand is
`nfeatures` from the earlier `nfeatures < 2` check.]
201201
else
202-
n_features_to_select = round(Int, n_features_to_select)
202+
n_features = round(Int, n_features)
203203
end
204204

205205
step = selector.step
@@ -216,7 +216,7 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
216216

217217
# Elimination
218218
features_left = copy(features)
219-
while sum(support) > n_features_to_select
219+
while sum(support) > n_features
220220
# Rank the remaining features
221221
model = selector.model
222222
verbosity > 0 && @info("Fitting estimator with $(sum(support)) features.")
@@ -239,7 +239,7 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
239239
ranks = sortperm(importances)
240240

241241
# Eliminate the worse features
242-
threshold = min(step, sum(support) - n_features_to_select)
242+
threshold = min(step, sum(support) - n_features)
243243

244244
support[indexes[ranks][1:threshold]] .= false
245245
ranking[.!support] .+= 1

Comments (0)