11from typing import List , Union
22
33import pandas as pd
4- from sklearn .base import TransformerMixin , BaseEstimator
5- from sklearn .utils .validation import check_is_fitted
64
75from feature_engine .dataframe_checks import (
86 _is_dataframe ,
9- _check_input_matches_training_df ,
107 _check_contains_na ,
118)
129from feature_engine .variable_manipulation import (
1310 _find_or_check_numerical_variables ,
1411 _check_input_parameter_variables ,
1512)
13+ from feature_engine .selection .base_selector import BaseSelector
1614
1715Variables = Union [None , int , str , List [Union [str , int ]]]
1816
1917
20- class DropCorrelatedFeatures (BaseEstimator , TransformerMixin ):
18+ class DropCorrelatedFeatures (BaseSelector ):
2119 """
2220 DropCorrelatedFeatures() finds and removes correlated features. Correlation is
2321 calculated with `pandas.corr()`.
@@ -52,14 +50,11 @@ class DropCorrelatedFeatures(BaseEstimator, TransformerMixin):
5250
5351 Attributes
5452 ----------
55- correlated_features_ :
56- Set with the correlated features.
53+ features_to_drop_ :
54+ Set with the correlated features that will be dropped.
5755
5856 correlated_feature_sets_:
59- Groups of correlated features. Each list is a group of correlated features.
60-
61- correlated_matrix_:
62- The correlation matrix.
57+ Groups of correlated features. Each list is a group of correlated features.
6358
6459 Methods
6560 -------
@@ -128,20 +123,20 @@ def fit(self, X: pd.DataFrame, y: pd.Series = None):
128123 _check_contains_na (X , self .variables )
129124
130125 # set to collect the correlated features that will be dropped
131- self .correlated_features_ = set ()
126+ self .features_to_drop_ = set ()
132127
133128 # create tuples of correlated feature groups
134129 self .correlated_feature_sets_ = []
135130
136131 # the correlation matrix
137- self . correlated_matrix_ = X [self .variables ].corr (method = self .method )
132+ _correlated_matrix = X [self .variables ].corr (method = self .method )
138133
139134 # create set of examined features, helps to determine feature combinations
140135 # to evaluate below
141136 _examined_features = set ()
142137
143138 # for each feature in the dataset (columns of the correlation matrix)
144- for feature in self . correlated_matrix_ .columns :
139+ for feature in _correlated_matrix .columns :
145140
146141 if feature not in _examined_features :
147142
@@ -155,20 +150,18 @@ def fit(self, X: pd.DataFrame, y: pd.Series = None):
155150 # features that have not been examined, are not currently examined and
156151 # were not found correlated
157152 _features_to_compare = [
158- f
159- for f in self .correlated_matrix_ .columns
160- if f not in _examined_features
153+ f for f in _correlated_matrix .columns if f not in _examined_features
161154 ]
162155
163156 # create combinations:
164157 for f2 in _features_to_compare :
165158
166159 # if the correlation is higher than the threshold
167160 # we are interested in absolute correlation coefficient value
168- if abs (self . correlated_matrix_ .loc [f2 , feature ]) > self .threshold :
161+ if abs (_correlated_matrix .loc [f2 , feature ]) > self .threshold :
169162
170163 # add feature (f2) to our correlated set
171- self .correlated_features_ .add (f2 )
164+ self .features_to_drop_ .add (f2 )
172165 _temp_set .add (f2 )
173166 _examined_features .add (f2 )
174167
@@ -180,35 +173,10 @@ def fit(self, X: pd.DataFrame, y: pd.Series = None):
180173
181174 return self
182175
183- def transform (self , X ):
184- """
185- Drop the correlated features from a dataframe.
186-
187- Parameters
188- ----------
189- X : pandas dataframe of shape = [n_samples, n_features].
190- The input samples.
191-
192- Returns
193- -------
194- X_transformed : pandas dataframe
195- shape = [n_samples, n_features - (correlated features)]
196- The transformed dataframe with the remaining subset of variables.
197- """
198- # check if fit is performed prior to transform
199- check_is_fitted (self )
200-
201- # check if input is a dataframe
202- X = _is_dataframe (X )
203-
204- # check if number of columns in test dataset matches to train dataset
205- _check_input_matches_training_df (X , self .input_shape_ [1 ])
206-
207- if self .missing_values == "raise" :
208- # check if dataset contains na
209- _check_contains_na (X , self .variables )
210-
211- # returned non-correlated features
212- X = X .drop (columns = self .correlated_features_ )
176+ # Workaround: transform() is overridden only so that Sphinx picks up the docstring copied from BaseSelector.transform below; it is otherwise not necessary.
177+ def transform (self , X : pd .DataFrame ) -> pd .DataFrame :
178+ X = super ().transform (X )
213179
214180 return X
181+
182+ transform .__doc__ = BaseSelector .transform .__doc__
0 commit comments