Skip to content

Commit cf1ba6f

Browse files
authored
adds support for all types of cross-validation schemes (#267)
1 parent 6cc9866 commit cf1ba6f

10 files changed

+498
-53
lines changed

feature_engine/selection/recursive_feature_addition.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,24 @@ class RecursiveFeatureAddition(BaseSelector):
6262
The threshold must be defined by the user. Bigger thresholds will select less
6363
features.
6464
65-
cv: int, default=3
66-
Cross-validation fold to be used to fit the estimator.
65+
cv: int, cross-validation generator or an iterable, default=3
66+
Determines the cross-validation splitting strategy. Possible inputs for cv are:
67+
68+
- None, to use cross_validate's default 5-fold cross validation
69+
70+
- int, to specify the number of folds in a (Stratified)KFold,
71+
72+
- CV splitter
73+
- (https://scikit-learn.org/stable/glossary.html#term-CV-splitter)
74+
75+
- An iterable yielding (train, test) splits as arrays of indices.
76+
77+
For int/None inputs, if the estimator is a classifier and y is either binary or
78+
multiclass, StratifiedKFold is used. In all other cases, KFold is used. These
79+
splitters are instantiated with shuffle=False so the splits will be the same
80+
across calls.
81+
82+
For more details check Scikit-learn's cross_validate documentation
6783
6884
Attributes
6985
----------
@@ -100,14 +116,11 @@ def __init__(
100116
self,
101117
estimator,
102118
scoring: str = "roc_auc",
103-
cv: int = 3,
119+
cv=3,
104120
threshold: Union[int, float] = 0.01,
105121
variables: Variables = None,
106122
):
107123

108-
if not isinstance(cv, int) or cv < 1:
109-
raise ValueError("cv can only take positive integers bigger than 1")
110-
111124
if not isinstance(threshold, (int, float)):
112125
raise ValueError("threshold can only be integer or float")
113126

feature_engine/selection/recursive_feature_elimination.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,24 @@ class RecursiveFeatureElimination(BaseSelector):
6262
The threshold must be defined by the user. Bigger thresholds will select less
6363
features.
6464
65-
cv: int, default=3
66-
Cross-validation fold to be used to fit the estimator.
65+
cv: int, cross-validation generator or an iterable, default=3
66+
Determines the cross-validation splitting strategy. Possible inputs for cv are:
67+
68+
- None, to use cross_validate's default 5-fold cross validation
69+
70+
- int, to specify the number of folds in a (Stratified)KFold,
71+
72+
- CV splitter
73+
- (https://scikit-learn.org/stable/glossary.html#term-CV-splitter)
74+
75+
- An iterable yielding (train, test) splits as arrays of indices.
76+
77+
For int/None inputs, if the estimator is a classifier and y is either binary or
78+
multiclass, StratifiedKFold is used. In all other cases, KFold is used. These
79+
splitters are instantiated with shuffle=False so the splits will be the same
80+
across calls.
81+
82+
For more details check Scikit-learn's cross_validate documentation
6783
6884
Attributes
6985
----------
@@ -99,14 +115,11 @@ def __init__(
99115
self,
100116
estimator,
101117
scoring: str = "roc_auc",
102-
cv: int = 3,
118+
cv=3,
103119
threshold: Union[int, float] = 0.01,
104120
variables: Variables = None,
105121
):
106122

107-
if not isinstance(cv, int) or cv < 1:
108-
raise ValueError("cv can only take positive integers bigger than 1")
109-
110123
if not isinstance(threshold, (int, float)):
111124
raise ValueError("threshold can only be integer or float")
112125

feature_engine/selection/shuffle_features.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,24 @@ class SelectByShuffling(BaseSelector):
6363
performance drift is smaller than the mean performance drift across all
6464
features.
6565
66-
cv: int, default=3
67-
Desired number of cross-validation fold to be used to fit the estimator.
66+
cv: int, cross-validation generator or an iterable, default=3
67+
Determines the cross-validation splitting strategy. Possible inputs for cv are:
68+
69+
- None, to use cross_validate's default 5-fold cross validation
70+
71+
- int, to specify the number of folds in a (Stratified)KFold,
72+
73+
- CV splitter
74+
- (https://scikit-learn.org/stable/glossary.html#term-CV-splitter)
75+
76+
- An iterable yielding (train, test) splits as arrays of indices.
77+
78+
For int/None inputs, if the estimator is a classifier and y is either binary or
79+
multiclass, StratifiedKFold is used. In all other cases, KFold is used. These
80+
splitters are instantiated with shuffle=False so the splits will be the same
81+
across calls.
82+
83+
For more details check Scikit-learn's cross_validate documentation
6884
6985
random_state: int, default=None
7086
Controls the randomness when shuffling features.
@@ -100,15 +116,12 @@ def __init__(
100116
self,
101117
estimator,
102118
scoring: str = "roc_auc",
103-
cv: int = 3,
119+
cv=3,
104120
threshold: Union[float, int] = None,
105121
variables: Variables = None,
106122
random_state: int = None,
107123
):
108124

109-
if not isinstance(cv, int) or cv < 1:
110-
raise ValueError("cv can only take positive integers bigger than 1")
111-
112125
if threshold and not isinstance(threshold, (int, float)):
113126
raise ValueError("threshold can only be integer or float or None")
114127

feature_engine/selection/single_feature_performance.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,24 @@ class SelectBySingleFeaturePerformance(BaseSelector):
5757
The threshold can be specified by the user. If None, it will be automatically
5858
set to the mean performance value of all features.
5959
60-
cv: int, default=3
61-
Desired number of cross-validation fold to be used to fit the estimator.
60+
cv: int, cross-validation generator or an iterable, default=3
61+
Determines the cross-validation splitting strategy. Possible inputs for cv are:
62+
63+
- None, to use cross_validate's default 5-fold cross validation
64+
65+
- int, to specify the number of folds in a (Stratified)KFold,
66+
67+
- CV splitter
68+
- (https://scikit-learn.org/stable/glossary.html#term-CV-splitter)
69+
70+
- An iterable yielding (train, test) splits as arrays of indices.
71+
72+
For int/None inputs, if the estimator is a classifier and y is either binary or
73+
multiclass, StratifiedKFold is used. In all other cases, KFold is used. These
74+
splitters are instantiated with shuffle=False so the splits will be the same
75+
across calls.
76+
77+
For more details check Scikit-learn's cross_validate documentation
6278
6379
Attributes
6480
----------
@@ -88,14 +104,11 @@ def __init__(
88104
self,
89105
estimator,
90106
scoring: str = "roc_auc",
91-
cv: int = 3,
107+
cv=3,
92108
threshold: Union[int, float] = None,
93109
variables: Variables = None,
94110
):
95111

96-
if not isinstance(cv, int) or cv < 1:
97-
raise ValueError("cv can only take positive integers bigger than 1")
98-
99112
if threshold:
100113
if not isinstance(threshold, (int, float)):
101114
raise ValueError("threshold can only be integer, float or None")

feature_engine/selection/smart_correlation_selection.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,24 @@ class SmartCorrelatedSelection(BaseSelector):
8383
sklearn.metrics. See the model evaluation documentation for more options:
8484
https://scikit-learn.org/stable/modules/model_evaluation.html
8585
86-
cv: int, default=3
87-
Cross-validation fold to be used to fit the estimator.
86+
cv: int, cross-validation generator or an iterable, default=3
87+
Determines the cross-validation splitting strategy. Possible inputs for cv are:
88+
89+
- None, to use cross_validate's default 5-fold cross validation
90+
91+
- int, to specify the number of folds in a (Stratified)KFold,
92+
93+
- CV splitter
94+
- (https://scikit-learn.org/stable/glossary.html#term-CV-splitter)
95+
96+
- An iterable yielding (train, test) splits as arrays of indices.
97+
98+
For int/None inputs, if the estimator is a classifier and y is either binary or
99+
multiclass, StratifiedKFold is used. In all other cases, KFold is used. These
100+
splitters are instantiated with shuffle=False so the splits will be the same
101+
across calls.
102+
103+
For more details check Scikit-learn's cross_validate documentation
88104
89105
Attributes
90106
----------
@@ -124,7 +140,7 @@ def __init__(
124140
selection_method: str = "missing_values",
125141
estimator=None,
126142
scoring: str = "roc_auc",
127-
cv: int = 3,
143+
cv=3,
128144
):
129145

130146
if method not in ["pearson", "spearman", "kendall"]:
@@ -149,9 +165,6 @@ def __init__(
149165
"'variance' or 'model_performance'."
150166
)
151167

152-
if not isinstance(cv, int) or cv < 1:
153-
raise ValueError("cv can only take positive integers bigger than 1")
154-
155168
if selection_method == "model_performance" and estimator is None:
156169
raise ValueError(
157170
"Please provide an estimator, e.g., "

tests/test_selection/test_recursive_feature_addition.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from sklearn.ensemble import RandomForestClassifier
55
from sklearn.exceptions import NotFittedError
66
from sklearn.linear_model import LinearRegression
7+
from sklearn.model_selection import KFold, StratifiedKFold
78
from sklearn.tree import DecisionTreeRegressor
89

910
from feature_engine.selection import RecursiveFeatureAddition
@@ -146,11 +147,6 @@ def test_non_fitted_error(df_test):
146147
sel.transform(df_test)
147148

148149

149-
def test_raises_cv_error():
150-
with pytest.raises(ValueError):
151-
RecursiveFeatureAddition(RandomForestClassifier(random_state=1), cv=0)
152-
153-
154150
def test_raises_threshold_error():
155151
with pytest.raises(ValueError):
156152
RecursiveFeatureAddition(RandomForestClassifier(random_state=1), threshold=None)
@@ -225,3 +221,83 @@ def test_automatic_variable_selection(df_test):
225221
assert list(sel.performance_drifts_.keys()) == ordered_features
226222
# test transform output
227223
pd.testing.assert_frame_equal(sel.transform(X), Xtransformed)
224+
225+
226+
def test_KFold_generators(df_test):
227+
228+
X, y = df_test
229+
230+
# Kfold
231+
sel = RecursiveFeatureAddition(
232+
RandomForestClassifier(random_state=1),
233+
threshold=0.001,
234+
cv=KFold(n_splits=3),
235+
)
236+
sel.fit(X, y)
237+
Xtransformed = sel.transform(X)
238+
239+
# test fit attrs
240+
assert sel.initial_model_performance_ > 0.995
241+
assert isinstance(sel.features_to_drop_, list)
242+
assert all([x for x in sel.features_to_drop_ if x in X.columns])
243+
assert len(sel.features_to_drop_) < X.shape[1]
244+
assert not Xtransformed.empty
245+
assert all([x for x in Xtransformed.columns if x not in sel.features_to_drop_])
246+
assert isinstance(sel.performance_drifts_, dict)
247+
assert all([x for x in X.columns if x in sel.performance_drifts_.keys()])
248+
assert all(
249+
[
250+
isinstance(sel.performance_drifts_[var], (int, float))
251+
for var in sel.performance_drifts_.keys()
252+
]
253+
)
254+
255+
# Stratified
256+
sel = RecursiveFeatureAddition(
257+
RandomForestClassifier(random_state=1),
258+
threshold=0.001,
259+
cv=StratifiedKFold(n_splits=3),
260+
)
261+
sel.fit(X, y)
262+
Xtransformed = sel.transform(X)
263+
264+
# test fit attrs
265+
assert sel.initial_model_performance_ > 0.995
266+
assert isinstance(sel.features_to_drop_, list)
267+
assert all([x for x in sel.features_to_drop_ if x in X.columns])
268+
assert len(sel.features_to_drop_) < X.shape[1]
269+
assert not Xtransformed.empty
270+
assert all([x for x in Xtransformed.columns if x not in sel.features_to_drop_])
271+
assert isinstance(sel.performance_drifts_, dict)
272+
assert all([x for x in X.columns if x in sel.performance_drifts_.keys()])
273+
assert all(
274+
[
275+
isinstance(sel.performance_drifts_[var], (int, float))
276+
for var in sel.performance_drifts_.keys()
277+
]
278+
)
279+
280+
# None
281+
sel = RecursiveFeatureAddition(
282+
RandomForestClassifier(random_state=1),
283+
threshold=0.001,
284+
cv=None,
285+
)
286+
sel.fit(X, y)
287+
Xtransformed = sel.transform(X)
288+
289+
# test fit attrs
290+
assert sel.initial_model_performance_ > 0.995
291+
assert isinstance(sel.features_to_drop_, list)
292+
assert all([x for x in sel.features_to_drop_ if x in X.columns])
293+
assert len(sel.features_to_drop_) < X.shape[1]
294+
assert not Xtransformed.empty
295+
assert all([x for x in Xtransformed.columns if x not in sel.features_to_drop_])
296+
assert isinstance(sel.performance_drifts_, dict)
297+
assert all([x for x in X.columns if x in sel.performance_drifts_.keys()])
298+
assert all(
299+
[
300+
isinstance(sel.performance_drifts_[var], (int, float))
301+
for var in sel.performance_drifts_.keys()
302+
]
303+
)

0 commit comments

Comments
 (0)