Skip to content

Commit 36b7a69

Browse files
authored
Fix bug in rfe when only 1 feature left (#639)
* add problematic test * fixes bug * fix code style
1 parent 15375a4 commit 36b7a69

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

feature_engine/selection/recursive_feature_elimination.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
169169
# remember that feature_importances_ is ordered already
170170
for feature in list(self.feature_importances_.index):
171171

172+
# if there is only 1 feature left
173+
if X_tmp.shape[1] == 1:
174+
self.performance_drifts_[feature] = 0
175+
_selected_features.append(feature)
176+
break
177+
172178
# remove feature and train new model
173179
model_tmp = cross_validate(
174180
self.estimator,
@@ -196,11 +202,6 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
196202
# remove feature and adjust initial performance
197203
X_tmp = X_tmp.drop(columns=feature)
198204

199-
if X_tmp.empty is True:
200-
raise ValueError(
201-
"All features have been removed. Try reducing the threshold."
202-
)
203-
204205
baseline_model = cross_validate(
205206
self.estimator,
206207
X_tmp,

tests/test_selection/test_recursive_feature_elimination.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22
import pytest
33
from sklearn.ensemble import RandomForestClassifier
4-
from sklearn.linear_model import Lasso, LogisticRegression
4+
from sklearn.linear_model import Lasso, LogisticRegression, LinearRegression
55
from sklearn.tree import DecisionTreeRegressor
66

77
from feature_engine.selection import RecursiveFeatureElimination
@@ -186,3 +186,25 @@ def test_regression(
186186

187187
# test transform output
188188
pd.testing.assert_frame_equal(sel.transform(X), Xtransformed)
189+
190+
191+
def test_stops_when_only_one_feature_remains():
192+
linear_model = LinearRegression()
193+
194+
# Feature x shows 100% correlation with target variable
195+
# Feature x shows 0% correlation with target variable
196+
# Target variable: y
197+
198+
df = pd.DataFrame(
199+
{
200+
"x": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
201+
"z": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
202+
"y": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
203+
}
204+
)
205+
206+
transformer = RecursiveFeatureElimination(
207+
estimator=linear_model, scoring="r2", cv=3
208+
)
209+
output = transformer.fit_transform(df[["x", "z"]], df["y"])
210+
pd.testing.assert_frame_equal(output, df["x"].to_frame())

0 commit comments

Comments
 (0)