|
4 | 4 | from sklearn.ensemble import RandomForestClassifier |
5 | 5 | from sklearn.exceptions import NotFittedError |
6 | 6 | from sklearn.linear_model import LinearRegression |
| 7 | +from sklearn.model_selection import KFold, StratifiedKFold |
7 | 8 | from sklearn.tree import DecisionTreeRegressor |
8 | 9 |
|
9 | 10 | from feature_engine.selection import RecursiveFeatureAddition |
@@ -146,11 +147,6 @@ def test_non_fitted_error(df_test): |
146 | 147 | sel.transform(df_test) |
147 | 148 |
|
148 | 149 |
|
149 | | -def test_raises_cv_error(): |
150 | | - with pytest.raises(ValueError): |
151 | | - RecursiveFeatureAddition(RandomForestClassifier(random_state=1), cv=0) |
152 | | - |
153 | | - |
154 | 150 | def test_raises_threshold_error(): |
155 | 151 | with pytest.raises(ValueError): |
156 | 152 | RecursiveFeatureAddition(RandomForestClassifier(random_state=1), threshold=None) |
@@ -225,3 +221,83 @@ def test_automatic_variable_selection(df_test): |
225 | 221 | assert list(sel.performance_drifts_.keys()) == ordered_features |
226 | 222 | # test transform output |
227 | 223 | pd.testing.assert_frame_equal(sel.transform(X), Xtransformed) |
| 224 | + |
| 225 | + |
| 226 | +def test_KFold_generators(df_test): |
| 227 | + |
| 228 | + X, y = df_test |
| 229 | + |
| 230 | + # KFold |
| 231 | + sel = RecursiveFeatureAddition( |
| 232 | + RandomForestClassifier(random_state=1), |
| 233 | + threshold=0.001, |
| 234 | + cv=KFold(n_splits=3), |
| 235 | + ) |
| 236 | + sel.fit(X, y) |
| 237 | + Xtransformed = sel.transform(X) |
| 238 | + |
| 239 | + # test fit attrs |
| 240 | + assert sel.initial_model_performance_ > 0.995 |
| 241 | + assert isinstance(sel.features_to_drop_, list) |
| 242 | + assert all([x for x in sel.features_to_drop_ if x in X.columns]) |
| 243 | + assert len(sel.features_to_drop_) < X.shape[1] |
| 244 | + assert not Xtransformed.empty |
| 245 | + assert all([x for x in Xtransformed.columns if x not in sel.features_to_drop_]) |
| 246 | + assert isinstance(sel.performance_drifts_, dict) |
| 247 | + assert all([x for x in X.columns if x in sel.performance_drifts_.keys()]) |
| 248 | + assert all( |
| 249 | + [ |
| 250 | + isinstance(sel.performance_drifts_[var], (int, float)) |
| 251 | + for var in sel.performance_drifts_.keys() |
| 252 | + ] |
| 253 | + ) |
| 254 | + |
| 255 | + # StratifiedKFold |
| 256 | + sel = RecursiveFeatureAddition( |
| 257 | + RandomForestClassifier(random_state=1), |
| 258 | + threshold=0.001, |
| 259 | + cv=StratifiedKFold(n_splits=3), |
| 260 | + ) |
| 261 | + sel.fit(X, y) |
| 262 | + Xtransformed = sel.transform(X) |
| 263 | + |
| 264 | + # test fit attrs |
| 265 | + assert sel.initial_model_performance_ > 0.995 |
| 266 | + assert isinstance(sel.features_to_drop_, list) |
| 267 | + assert all([x for x in sel.features_to_drop_ if x in X.columns]) |
| 268 | + assert len(sel.features_to_drop_) < X.shape[1] |
| 269 | + assert not Xtransformed.empty |
| 270 | + assert all([x for x in Xtransformed.columns if x not in sel.features_to_drop_]) |
| 271 | + assert isinstance(sel.performance_drifts_, dict) |
| 272 | + assert all([x for x in X.columns if x in sel.performance_drifts_.keys()]) |
| 273 | + assert all( |
| 274 | + [ |
| 275 | + isinstance(sel.performance_drifts_[var], (int, float)) |
| 276 | + for var in sel.performance_drifts_.keys() |
| 277 | + ] |
| 278 | + ) |
| 279 | + |
| 280 | + # None |
| 281 | + sel = RecursiveFeatureAddition( |
| 282 | + RandomForestClassifier(random_state=1), |
| 283 | + threshold=0.001, |
| 284 | + cv=None, |
| 285 | + ) |
| 286 | + sel.fit(X, y) |
| 287 | + Xtransformed = sel.transform(X) |
| 288 | + |
| 289 | + # test fit attrs |
| 290 | + assert sel.initial_model_performance_ > 0.995 |
| 291 | + assert isinstance(sel.features_to_drop_, list) |
| 292 | + assert all([x for x in sel.features_to_drop_ if x in X.columns]) |
| 293 | + assert len(sel.features_to_drop_) < X.shape[1] |
| 294 | + assert not Xtransformed.empty |
| 295 | + assert all([x for x in Xtransformed.columns if x not in sel.features_to_drop_]) |
| 296 | + assert isinstance(sel.performance_drifts_, dict) |
| 297 | + assert all([x for x in X.columns if x in sel.performance_drifts_.keys()]) |
| 298 | + assert all( |
| 299 | + [ |
| 300 | + isinstance(sel.performance_drifts_[var], (int, float)) |
| 301 | + for var in sel.performance_drifts_.keys() |
| 302 | + ] |
| 303 | + ) |
0 commit comments