|
| 1 | +RecursiveFeatureAddition |
| 2 | +======================== |
| 3 | + |
| 4 | +API Reference |
| 5 | +------------- |
| 6 | + |
| 7 | +.. autoclass:: feature_engine.selection.RecursiveFeatureAddition |
| 8 | + :members: |
| 9 | + |
| 10 | + |
| 11 | +Example |
| 12 | +------- |
| 13 | + |
| 14 | +.. code:: python |
| 15 | +
|
| 16 | + import pandas as pd |
| 17 | + from sklearn.datasets import load_diabetes |
| 18 | + from sklearn.linear_model import LinearRegression |
| 19 | + from feature_engine.selection import RecursiveFeatureElimination |
| 20 | +
|
| 21 | + # load dataset |
| 22 | + diabetes_X, diabetes_y = load_diabetes(return_X_y=True) |
| 23 | + X = pd.DataFrame(diabetes_X) |
| 24 | + y = pd.DataFrame(diabetes_y) |
| 25 | +
|
| 26 | + # initialize linear regresion estimator |
| 27 | + linear_model = LinearRegression() |
| 28 | +
|
| 29 | + # initialize feature selector |
| 30 | + tr = RecursiveFeatureElimination(estimator=linear_model, scoring="r2", cv=3) |
| 31 | +
|
| 32 | + # fit transformer |
| 33 | + Xt = tr.fit_transform(X, y) |
| 34 | +
|
| 35 | + # get the initial linear model performance, using all features |
| 36 | + tr.initial_model_performance_ |
| 37 | +
|
| 38 | +.. code:: python |
| 39 | +
|
| 40 | + 0.488702767247119 |
| 41 | +
|
| 42 | +.. code:: python |
| 43 | +
|
| 44 | + # Get the performance drift of each feature |
| 45 | + tr.performance_drifts_ |
| 46 | +
|
| 47 | +.. code:: python |
| 48 | +
|
| 49 | + {4: 0, |
| 50 | + 8: 0.2837159006046677, |
| 51 | + 2: 0.1377700238871593, |
| 52 | + 5: 0.0023329006089969906, |
| 53 | + 3: 0.0187608758643259, |
| 54 | + 1: 0.0027994385024313617, |
| 55 | + 7: 0.0026951300105543807, |
| 56 | + 6: 0.002683967832484757, |
| 57 | + 9: 0.0003040126429713075, |
| 58 | + 0: -0.007386876030245182} |
| 59 | +
|
| 60 | +.. code:: python |
| 61 | +
|
| 62 | + # get the selected features |
| 63 | + tr.selected_features_ |
| 64 | +
|
| 65 | +.. code:: python |
| 66 | +
|
| 67 | + [4, 8, 2, 3] |
| 68 | +
|
| 69 | +.. code:: python |
| 70 | +
|
| 71 | + print(Xt.head()) |
| 72 | +
|
| 73 | +.. code:: python |
| 74 | +
|
| 75 | + 4 8 2 3 |
| 76 | + 0 -0.044223 0.019908 0.061696 0.021872 |
| 77 | + 1 -0.008449 -0.068330 -0.051474 -0.026328 |
| 78 | + 2 -0.045599 0.002864 0.044451 -0.005671 |
| 79 | + 3 0.012191 0.022692 -0.011595 -0.036656 |
| 80 | + 4 0.003935 -0.031991 -0.036385 0.021872 |
| 81 | +
|
0 commit comments