@@ -41,15 +41,11 @@ eval(:(const RFE{M} =
 
 # Common keyword constructor for both model types
 """
-    RecursiveFeatureElimination(model, n_features, step)
+    RecursiveFeatureElimination(model; n_features=0, step=1)
 
 This model implements a recursive feature elimination algorithm for feature selection.
 It recursively removes features, training a base model on the remaining features and
 evaluating their importance until the desired number of features is selected.
-
-Construct an instance with default hyper-parameters using the syntax
-`rfe_model = RecursiveFeatureElimination(model=...)`. Provide keyword arguments to override
-hyper-parameter defaults.
 
 # Training data
 In MLJ or MLJBase, bind an instance `rfe_model` to data with
@@ -92,12 +88,11 @@ Train the machine using `fit!(mach, rows=...)`.
 # Operations
 
 - `transform(mach, X)`: transform the input table `X` into a new table containing only
-  columns corresponding to features gotten from the RFE algorithm.
+  columns corresponding to features accepted by the RFE algorithm.
 
 - `predict(mach, X)`: transform the input table `X` into a new table same as in
-
-  `transform(mach, X)` above and predict using the fitted base model on the
-  transformed table.
+  `transform(mach, X)` above and predict using the fitted base model on the transformed
+  table.
 
 # Fitted parameters
 The fields of `fitted_params(mach)` are:
@@ -108,37 +103,43 @@ The fields of `fitted_params(mach)` are:
 # Report
 The fields of `report(mach)` are:
 - `scores`: dictionary of scores for each feature in the training dataset.
-  The model deems highly scored variables more significant.
+  The model deems highly scored variables more significant.
 
 - `model_report`: report for the fitted base model.
 
 
 # Examples
+
+The following example assumes you have MLJDecisionTreeInterface in the active package
+environment.
+
 ```
-using FeatureSelection, MLJ, StableRNGs
+using MLJ
 
 RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree
 
 # Creates a dataset where the target only depends on the first 5 columns of the input table.
-A = rand(rng, 50, 10);
+A = rand(50, 10);
 y = 10 .* sin.(
     pi .* A[:, 1] .* A[:, 2]
-) + 20 .* (A[:, 3] .- 0.5).^2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5]);
+) + 20 .* (A[:, 3] .- 0.5).^2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5];
 X = MLJ.table(A);
 
-# fit an rfe model
+# fit an rfe model:
 rf = RandomForestRegressor()
-selector = RecursiveFeatureElimination(model=rf)
+selector = RecursiveFeatureElimination(rf, n_features=2)
 mach = machine(selector, X, y)
 fit!(mach)
 
 # view the feature importances
 feature_importances(mach)
 
-# predict using the base model
-Xnew = MLJ.table(rand(rng, 50, 10));
+# predict using the base model trained on the reduced feature set:
+Xnew = MLJ.table(rand(50, 10));
 predict(mach, Xnew)
 
+# transform data with all features to the reduced feature set:
+transform(mach, Xnew)
 ```
 """
 function RecursiveFeatureElimination(
@@ -173,7 +174,7 @@ function RecursiveFeatureElimination(
         # This branch is hit just in case there are any models supporting feature
         # importance that aren't `<:Probabilistic` or `<:Deterministic`,
         # which is rare.
-        throw(ERR_MODEL_TYPE)
+        throw(ERR_MODEL_TYPE)
     end
     message = MMI.clean!(selector)
     isempty(message) || @warn(message)
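
For context, the `else` branch above rejects base models that report feature importances but are neither `<:Probabilistic` nor `<:Deterministic`. A minimal sketch of the dispatch idea, assuming MLJModelInterface is available; the helper name `wrapper_kind` is hypothetical and is not package code:

```julia
import MLJModelInterface as MMI

# Hypothetical illustration of the constructor's branching: choose a wrapper
# flavour from the base model's prediction type; anything else is rejected,
# mirroring the `throw(ERR_MODEL_TYPE)` above.
wrapper_kind(::MMI.Deterministic) = :deterministic
wrapper_kind(::MMI.Probabilistic) = :probabilistic
wrapper_kind(model) = throw(ArgumentError("unsupported model type: $(typeof(model))"))
```
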
@@ -214,19 +215,19 @@ abs_last(x::Pair{<:Any, <:Real}) = abs(last(x))
 """
     score_features!(scores_dict, features, importances, n_features_to_score)
 
-Internal method that updates the `scores_dict` by increasing the score for each feature based on their
+Internal method that updates the `scores_dict` by increasing the score for each feature based on its
 importance and stores the features in the `features` array.
 
 # Arguments
-- `scores_dict::Dict{Symbol, Int}`: A dictionary where the keys are features and
+- `scores_dict::Dict{Symbol, Int}`: A dictionary where the keys are features and
   the values are their corresponding scores.
 - `features::Vector{Symbol}`: An array to store the top features based on importance.
-- `importances::Vector{Pair(Symbol, <:Real)}}`: An array of tuples where each tuple
-  contains a feature and its importance score.
+- `importances::Vector{<:Pair{Symbol, <:Real}}`: An array of pairs, where each pair
+  contains a feature and its importance score.
 - `n_features_to_score::Int`: The number of top features to score and store.
 
 # Notes
-Ensure that `n_features_to_score` is less than or equal to the minimum of the
+Ensure that `n_features_to_score` is less than or equal to the minimum of the
 lengths of `features` and `importances`.
 
 # Example
@@ -244,7 +245,7 @@ features == [:feature1, :feature2, :x1]
 function score_features!(scores_dict, features, importances, n_features_to_score)
     for i in Base.OneTo(n_features_to_score)
         ftr = first(importances[i])
-        features[i] = ftr
+        features[i] = ftr
         scores_dict[ftr] += 1
     end
 end
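
A runnable sketch of how the internal method above behaves, consistent with its docstring; the data values are illustrative only:

```julia
# scores_dict tracks cumulative scores; features receives the top-ranked names.
scores_dict = Dict(:feature1 => 0, :feature2 => 0, :x1 => 0)
features = Vector{Symbol}(undef, 3)
importances = [:feature1 => 0.9, :feature2 => 0.5, :x1 => 0.1]

# Score the top two features: their dictionary entries are incremented and their
# names are stored in `features[1:2]`; `:x1` is left untouched.
score_features!(scores_dict, features, importances, 2)
# scores_dict == Dict(:feature1 => 1, :feature2 => 1, :x1 => 0)
# features[1:2] == [:feature1, :feature2]
```
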
@@ -273,7 +274,7 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
                 "n_features > number of features in training data, " *
                 "hence no feature will be eliminated."
             )
-        end
+        end
     end
 
     _step = selector.step
@@ -296,17 +297,17 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
         verbosity > 0 && @info("Fitting estimator with $(n_features_to_keep) features.")
         data = MMI.reformat(model, MMI.selectcols(X, features_left), args...)
         fitresult, _, report = MMI.fit(model, verbosity - 1, data...)
-        # Note that the MLJ feature importance API does not impose any restrictions on the
-        # ordering of `feature => score` pairs in the `importances` vector.
+        # Note that the MLJ feature importance API does not impose any restrictions on the
+        # ordering of `feature => score` pairs in the `importances` vector.
         # Therefore, the order of `feature => score` pairs in the `importances` vector
-        # might differ from the order of features in the `features` vector, which is
+        # might differ from the order of features in the `features` vector, which is
         # extracted from the feature matrix `X` above. Hence the need for a dictionary
         # implementation.
         importances = MMI.feature_importances(
             selector.model,
             fitresult,
             report
-        )
+        )
 
         # Eliminate the worst features and increase the score of the remaining features
         sort!(importances, by=abs_last, rev=true)
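
For intuition, a self-contained sketch of the elimination round implemented by the hunk above. The helper name `eliminate_round` is hypothetical; the package itself updates scores and feature arrays in place across rounds rather than calling such a function:

```julia
# One RFE round (illustrative only): sort feature => importance pairs by absolute
# importance, as `abs_last` does above, then drop the `step` least important features.
function eliminate_round(importances::Vector{<:Pair{Symbol, <:Real}}, step::Int)
    sorted = sort(importances, by = p -> abs(last(p)), rev = true)
    return first.(sorted[1:end - step])  # names of the surviving features
end

eliminate_round([:x1 => 0.9, :x2 => 0.05, :x3 => -0.4], 1)  # => [:x1, :x3]
```
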