@@ -41,17 +41,14 @@ eval(:(const RFE{M} =
# Common keyword constructor for both model types
"""
- RecursiveFeatureElimination(model, n_features, step)
+ RecursiveFeatureElimination(model; n_features=0, step=1)

This model implements a recursive feature elimination algorithm for feature selection.
It recursively removes features, training a base model on the remaining features and
evaluating their importance until the desired number of features is selected.
-
- Construct an instance with default hyper-parameters using the syntax
- `rfe_model = RecursiveFeatureElimination(model=...)`. Provide keyword arguments to override
- hyper-parameter defaults.

# Training data
+
In MLJ or MLJBase, bind an instance `rfe_model` to data with
mach = machine(rfe_model, X, y)
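As a quick orientation for reviewers, here is a minimal sketch of the keyword form documented in this hunk; `rf` stands in for any already-constructed MLJ model that reports feature importances and is purely illustrative:

```julia
# Keyword constructor from the updated docstring;
# n_features=0 and step=1 are the documented defaults.
rfe_model = RecursiveFeatureElimination(rf; n_features=2, step=1)
```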
@@ -92,53 +89,62 @@ Train the machine using `fit!(mach, rows=...)`.
# Operations

- `transform(mach, X)`: transform the input table `X` into a new table containing only
- columns corresponding to features gotten from the RFE algorithm.
+ columns corresponding to features accepted by the RFE algorithm.

- `predict(mach, X)`: transform the input table `X` into a new table same as in
-
- - `transform(mach, X)` above and predict using the fitted base model on the
- transformed table.
+ `transform(mach, X)` above and predict using the fitted base model on the transformed
+ table.

# Fitted parameters
+
The fields of `fitted_params(mach)` are:
+
- `features_left`: names of features remaining after recursive feature elimination.

- `model_fitresult`: fitted parameters of the base model.

# Report
+
The fields of `report(mach)` are:
+
- `scores`: dictionary of scores for each feature in the training dataset.
- The model deems highly scored variables more significant.
+ The model deems highly scored variables more significant.

- `model_report`: report for the fitted base model.

# Examples
+
+ The following example assumes you have MLJDecisionTreeInterface in the active package
+ environment.
+
```
- using FeatureSelection, MLJ, StableRNGs
+ using MLJ

RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree

# Creates a dataset where the target only depends on the first 5 columns of the input table.
- A = rand(rng, 50, 10);
+ A = rand(50, 10);
y = 10 .* sin.(
pi .* A[:, 1] .* A[:, 2]
- ) + 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5]) ;
+ ) + 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5];
X = MLJ.table(A);

- # fit a rfe model
+ # fit a rfe model:
rf = RandomForestRegressor()
- selector = RecursiveFeatureElimination(model = rf )
+ selector = RecursiveFeatureElimination(rf, n_features=2)
mach = machine(selector, X, y)
fit!(mach)

# view the feature importances
feature_importances(mach)

- # predict using the base model
- Xnew = MLJ.table(rand(rng, 50, 10));
+ # predict using the base model trained on the reduced feature set:
+ Xnew = MLJ.table(rand(50, 10));
predict(mach, Xnew)

+ # transform data with all features to the reduced feature set:
+ transform(mach, Xnew)
```
"""
function RecursiveFeatureElimination(
@@ -173,7 +179,7 @@ function RecursiveFeatureElimination(
# This branch is hit just in case there are any models that support
# feature importance that aren't `<:Probabilistic` or `<:Deterministic`
# which is rare.
- throw(ERR_MODEL_TYPE)
+ throw(ERR_MODEL_TYPE)
end
message = MMI.clean!(selector)
isempty(message) || @warn(message)
@@ -214,22 +220,30 @@ abs_last(x::Pair{<:Any, <:Real}) = abs(last(x))
"""
score_features!(scores_dict, features, importances, n_features_to_score)

- Internal method that updates the `scores_dict` by increasing the score for each feature based on their
+ **Private method.**
+
+ Update the `scores_dict` by increasing the score for each feature based on their
importance and store the features in the `features` array.

# Arguments
- - `scores_dict::Dict{Symbol, Int}`: A dictionary where the keys are features and
+
+ - `scores_dict::Dict{Symbol, Int}`: A dictionary where the keys are features and
the values are their corresponding scores.
+
- `features::Vector{Symbol}`: An array to store the top features based on importance.
- - `importances::Vector{Pair(Symbol, <:Real)}}`: An array of tuples where each tuple
- contains a feature and its importance score.
+
+ - `importances::Vector{Pair{Symbol, <:Real}}`: An array of pairs where each pair
+ contains a feature and its importance score.
+
- `n_features_to_score::Int`: The number of top features to score and store.

# Notes
- Ensure that `n_features_to_score` is less than or equal to the minimum of the
+
+ Ensure that `n_features_to_score` is less than or equal to the minimum of the
lengths of `features` and `importances`.

# Example
+
```julia
scores_dict = Dict(:feature1 => 0, :feature2 => 0, :feature3 => 0)
features = [:x1, :x1, :x1]
@@ -244,7 +258,7 @@ features == [:feature1, :feature2, :x1]
function score_features!(scores_dict, features, importances, n_features_to_score)
for i in Base.OneTo(n_features_to_score)
ftr = first(importances[i])
- features[i] = ftr
+ features[i] = ftr
scores_dict[ftr] += 1
end
end
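A self-contained check of `score_features!` as defined above; the feature names and importance values here are invented purely for illustration:

```julia
scores = Dict(:age => 0, :height => 0, :weight => 0)
slots = [:tmp, :tmp, :tmp]
imps = [:weight => 0.8, :age => 0.3, :height => 0.1]  # already sorted by importance

score_features!(scores, slots, imps, 2)

scores  # Dict(:weight => 1, :age => 1, :height => 0) -- top two features scored
slots   # [:weight, :age, :tmp] -- first two slots overwritten
```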
@@ -273,7 +287,7 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
"n_features > number of features in training data, " *
"hence no feature will be eliminated."
)
- end
+ end
end

_step = selector.step
@@ -296,17 +310,17 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
verbosity > 0 && @info("Fitting estimator with $(n_features_to_keep) features.")
data = MMI.reformat(model, MMI.selectcols(X, features_left), args...)
fitresult, _, report = MMI.fit(model, verbosity - 1, data...)
- # Note that the MLJ feature importance API does not impose any restrictions on the
- # ordering of `feature => score` pairs in the `importances` vector.
+ # Note that the MLJ feature importance API does not impose any restrictions on the
+ # ordering of `feature => score` pairs in the `importances` vector.
# Therefore, the order of `feature => score` pairs in the `importances` vector
- # might differ from the order of features in the `features` vector, which is
+ # might differ from the order of features in the `features` vector, which is
# extracted from the feature matrix `X` above. Hence the need for a dictionary
# implementation.
importances = MMI.feature_importances(
selector.model,
fitresult,
report
- )
+ )

# Eliminate the worst features and increase the score of the remaining features
sort!(importances, by=abs_last, rev=true)
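To make the comment about ordering concrete, here is a small, invented illustration of the sorting step above, using `abs_last` as defined earlier in this file:

```julia
abs_last(x::Pair{<:Any, <:Real}) = abs(last(x))  # repeated here so the snippet runs standalone

importances = [:x1 => 0.2, :x3 => -0.7, :x2 => 0.9]   # arbitrary order, as the comment warns
sort!(importances, by=abs_last, rev=true)
importances == [:x2 => 0.9, :x3 => -0.7, :x1 => 0.2]  # ranked by absolute score, largest first
```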
MMI.load_path(::Type{<:RFE}) = "FeatureSelection.RecursiveFeatureElimination"
MMI.constructor(::Type{<:RFE}) = RecursiveFeatureElimination
MMI.package_name(::Type{<:RFE}) = "FeatureSelection"
+ MMI.is_wrapper(::Type{<:RFE}) = true

for trait in [
:supports_weights,