
Commit d97b0d0

solegalli and SunnyxBd authored
Recursive feature addition, closes #135 (#188)
* add new file for recursive feature addition
* add feature selection code
* add recursive feature addition to init
* add unit tests
* fix code style
* delete comment
* add selected_features attribute to list of attributes
* delete white space
* delete unnecessary baseline model
* finish docstring and fix performance drift dict

Co-authored-by: SunnyxBd <[email protected]>
1 parent 7703e58 commit d97b0d0

File tree

3 files changed: +392 -0 lines changed

feature_engine/selection/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
 from .drop_correlated_features import DropCorrelatedFeatures
 from .shuffle_features import SelectByShuffling
 from .single_feature_performance import SelectBySingleFeaturePerformance
+from .recursive_feature_addition import RecursiveFeatureAddition
 from .recursive_feature_elimination import RecursiveFeatureElimination
 from .target_mean_selection import SelectByTargetMeanPerformance

@@ -17,6 +18,7 @@
     "DropCorrelatedFeatures",
     "SelectByShuffling",
     "SelectBySingleFeaturePerformance",
+    "RecursiveFeatureAddition",
     "RecursiveFeatureElimination",
     "SelectByTargetMeanPerformance",
 ]
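With this export in place, the new selector becomes importable from the package's public selection module. A minimal smoke test of the import (a hypothetical usage sketch, not part of this commit):

from feature_engine.selection import RecursiveFeatureAddition

# defaults per the __init__ signature below: RandomForestClassifier estimator,
# roc_auc scoring, cv=3, threshold=0.01
selector = RecursiveFeatureAddition()
print(type(selector).__name__)  # RecursiveFeatureAddition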
feature_engine/selection/recursive_feature_addition.py (new file)

Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@

from typing import List, Union

import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate
from sklearn.utils.validation import check_is_fitted

from feature_engine.dataframe_checks import (
    _is_dataframe,
    _check_input_matches_training_df,
)
from feature_engine.selection.base_selector import get_feature_importances
from feature_engine.variable_manipulation import (
    _check_input_parameter_variables,
    _find_or_check_numerical_variables,
)

Variables = Union[None, int, str, List[Union[str, int]]]


class RecursiveFeatureAddition(BaseEstimator, TransformerMixin):
    """
    RecursiveFeatureAddition selects features following a recursive process.

    The process is as follows:

    1. Train an estimator using all the features.

    2. Rank the features according to their importance, derived from the estimator.

    3. Train an estimator with the most important feature and determine its
    performance.

    4. Add the second most important feature and train a new estimator.

    5. Calculate the difference in performance between the last estimator and the
    previous one.

    6. If the performance increases beyond the threshold, the feature is important
    and will be kept. Otherwise, that feature is removed.

    7. Repeat steps 4-6 until all features have been evaluated.

    Model training and performance calculation are done with cross-validation.

    Parameters
    ----------
    variables : str or list, default=None
        The list of variables to be evaluated. If None, the transformer will
        evaluate all numerical features in the dataset.

    estimator : object, default = RandomForestClassifier()
        A Scikit-learn estimator for regression or classification.
        The estimator must have either a `feature_importances_` or `coef_`
        attribute after fitting.

    scoring : str, default='roc_auc'
        Desired metric to optimise the performance of the estimator. Comes from
        sklearn.metrics. See the model evaluation documentation for more options:
        https://scikit-learn.org/stable/modules/model_evaluation.html

    threshold : float, int, default = 0.01
        The value that defines if a feature will be kept or removed. Note that for
        metrics like roc-auc, r2_score and accuracy, the threshold will be a float
        between 0 and 1. For metrics like the mean squared error and the root mean
        squared error, the threshold can be a much bigger number.
        The threshold must be defined by the user. Bigger thresholds will select
        fewer features.

    cv : int, default=3
        Number of cross-validation folds used to fit the estimator.

    Attributes
    ----------
    initial_model_performance_ :
        Performance of the model trained using the original dataset.

    feature_importances_ :
        Pandas Series with the feature importance.

    performance_drifts_ :
        Dictionary with the performance drift per added feature.

    selected_features_ :
        List with the selected features.

    Methods
    -------
    fit:
        Find the important features.
    transform:
        Reduce X to the selected features.
    fit_transform:
        Fit to data, then transform it.
    """

    def __init__(
        self,
        estimator=RandomForestClassifier(),
        scoring: str = "roc_auc",
        cv: int = 3,
        threshold: Union[int, float] = 0.01,
        variables: Variables = None,
    ):

        if not isinstance(cv, int) or cv < 1:
            raise ValueError("cv can only take positive integers")

        if not isinstance(threshold, (int, float)):
            raise ValueError("threshold can only be integer or float")

        self.variables = _check_input_parameter_variables(variables)
        self.estimator = estimator
        self.scoring = scoring
        self.threshold = threshold
        self.cv = cv

    def fit(self, X: pd.DataFrame, y: pd.Series):
        """
        Find the important features. Note that the selector trains various models
        at each round of selection, so it might take a while.

        Parameters
        ----------
        X : pandas dataframe of shape = [n_samples, n_features]
            The input dataframe.

        y : array-like of shape (n_samples)
            Target variable. Required to train the estimator.

        Returns
        -------
        self
        """

        # check input dataframe
        X = _is_dataframe(X)

        # find numerical variables or check variables entered by user
        self.variables = _find_or_check_numerical_variables(X, self.variables)

        # train model with all features and cross-validation
        model = cross_validate(
            self.estimator,
            X[self.variables],
            y,
            cv=self.cv,
            scoring=self.scoring,
            return_estimator=True,
        )

        # store initial model performance
        self.initial_model_performance_ = model["test_score"].mean()

        # Initialize a dataframe that will contain the feature/coefficient
        # importance derived from each cross-validation fold
        feature_importances_cv = pd.DataFrame()

        # Populate the feature_importances_cv dataframe with columns containing
        # the feature importance values for each model returned by the cross
        # validation.
        # There are as many columns as folds.
        for m in model["estimator"]:
            feature_importances_cv[m] = get_feature_importances(m)

        # Add the variables as index to feature_importances_cv
        feature_importances_cv.index = self.variables

        # Aggregate the feature importance returned in each fold
        self.feature_importances_ = feature_importances_cv.mean(axis=1)

        # Sort the feature importance values in descending order
        self.feature_importances_.sort_values(ascending=False, inplace=True)

        # Extract the most important feature from the ordered list of features
        first_most_important_feature = list(self.feature_importances_.index)[0]

        # Run baseline model using only the most important feature
        baseline_model = cross_validate(
            self.estimator,
            X[first_most_important_feature].to_frame(),
            y,
            cv=self.cv,
            scoring=self.scoring,
            return_estimator=True,
        )

        # Save baseline model performance
        baseline_model_performance = baseline_model["test_score"].mean()

        # list to collect selected features
        # It is initialized with the most important feature
        self.selected_features_ = [first_most_important_feature]

        # dict to collect features and their performance_drift
        # It is initialized with the performance drift of
        # the most important feature
        self.performance_drifts_ = {first_most_important_feature: 0}

        # loop over the ordered list of features by feature importance starting
        # from the second element in the list.
        for feature in list(self.feature_importances_.index)[1:]:

            # Add feature and train new model
            model_tmp = cross_validate(
                self.estimator,
                X[self.selected_features_ + [feature]],
                y,
                cv=self.cv,
                scoring=self.scoring,
                return_estimator=True,
            )

            # assign new model performance
            model_tmp_performance = model_tmp["test_score"].mean()

            # Calculate performance drift
            performance_drift = model_tmp_performance - baseline_model_performance

            # Save feature and performance drift
            self.performance_drifts_[feature] = performance_drift

            # If the performance drift is bigger than the threshold,
            # the feature is important and will be kept
            if performance_drift > self.threshold:

                # add feature to the list of selected features
                self.selected_features_.append(feature)

                # Update new baseline model performance
                baseline_model_performance = model_tmp_performance

        self.input_shape_ = X.shape

        return self

    def transform(self, X: pd.DataFrame):
        """
        Return dataframe with selected features.

        Parameters
        ----------
        X : pandas dataframe of shape = [n_samples, n_features]
            The input dataframe.

        Returns
        -------
        X_transformed : pandas dataframe of shape = [n_samples, n_selected_features]
            Pandas dataframe with the selected features.
        """

        # check if fit is performed prior to transform
        check_is_fitted(self)

        # check if input is a dataframe
        X = _is_dataframe(X)

        # check if number of columns in test dataset matches to train dataset
        _check_input_matches_training_df(X, self.input_shape_[1])

        # return the dataframe with the selected features
        return X[self.selected_features_]
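For context, a minimal usage sketch of the new transformer (hypothetical, not part of this commit; assumes scikit-learn >= 0.23 for `as_frame=True` and uses the default RandomForestClassifier):

import pandas as pd
from sklearn.datasets import load_breast_cancer
from feature_engine.selection import RecursiveFeatureAddition

# load a toy classification dataset as a pandas dataframe
X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# keep a feature only if adding it improves the mean cv roc_auc
# by more than 0.001 over the current baseline
selector = RecursiveFeatureAddition(scoring="roc_auc", cv=3, threshold=0.001)
X_reduced = selector.fit_transform(X, y)

print(selector.initial_model_performance_)  # mean cv score with all features
print(selector.performance_drifts_)         # drift recorded per evaluated feature
print(selector.selected_features_)          # features retained by the selector

Each entry in performance_drifts_ is the difference between the cv score with the candidate feature added and the score of the current baseline model; the baseline is updated only when a feature clears the threshold, which is why drifts are relative to the best model so far rather than to the initial single-feature model.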
