Skip to content

Commit a072085

Browse files
Speed up DropDuplicateFeatures (#614)
* Speed up DropDuplicateFeatures * Extend the test case to test different data types * Fix stylechecks issue
1 parent 31f6215 commit a072085

File tree

2 files changed

+53
-33
lines changed

2 files changed

+53
-33
lines changed

feature_engine/selection/drop_duplicate_features.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
from typing import List, Union
23

34
import pandas as pd
@@ -99,7 +100,6 @@ def __init__(
99100
missing_values: str = "ignore",
100101
confirm_variables: bool = False,
101102
):
102-
103103
if missing_values not in ["raise", "ignore"]:
104104
raise ValueError("missing_values takes only values 'raise' or 'ignore'.")
105105

@@ -136,42 +136,30 @@ def fit(self, X: pd.DataFrame, y: pd.Series = None):
136136
# check if dataset contains na
137137
_check_contains_na(X, self.variables_)
138138

139-
# create tuples of duplicated feature groups
140-
self.duplicated_feature_sets_ = []
141-
142-
# set to collect features that are duplicated
143-
self.features_to_drop_ = set() # type: ignore
144-
145-
# create set of examined features
146-
_examined_features = set()
147-
148-
for feature in self.variables_:
139+
# collect duplicate features
140+
_features_hashmap = defaultdict(list)
149141

150-
# append so we can remove when we create the combinations
151-
_examined_features.add(feature)
142+
# hash the features
143+
_X_hash = pd.util.hash_pandas_object(X[self.variables_].T, index=False)
152144

153-
if feature not in self.features_to_drop_:
145+
# group the features by hash
146+
for feature, feature_hash in _X_hash.items():
147+
_features_hashmap[feature_hash].append(feature)
154148

155-
_temp_set = set([feature])
156-
157-
# features that have not been examined, are not currently examined and
158-
# were not found duplicates
159-
_features_to_compare = [
160-
f
161-
for f in self.variables_
162-
if f not in _examined_features.union(self.features_to_drop_)
163-
]
164-
165-
# create combinations:
166-
for f2 in _features_to_compare:
167-
168-
if X[feature].equals(X[f2]):
169-
self.features_to_drop_.add(f2)
170-
_temp_set.add(f2)
149+
# create tuples of duplicated feature groups
150+
self.duplicated_feature_sets_ = [
151+
set(duplicate)
152+
for duplicate in _features_hashmap.values()
153+
if len(duplicate) > 1
154+
]
171155

172-
# if there are duplicated features
173-
if len(_temp_set) > 1:
174-
self.duplicated_feature_sets_.append(_temp_set)
156+
# set to collect features that are duplicated
157+
self.features_to_drop_ = {
158+
item
159+
for duplicates in _features_hashmap.values()
160+
for item in duplicates[1:]
161+
if duplicates and len(duplicates) > 1
162+
}
175163

176164
# save input features
177165
self._get_feature_names_in(X)

tests/test_selection/test_drop_duplicate_features.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,23 @@ def df_duplicate_features_with_na():
4343
return df
4444

4545

46+
@pytest.fixture(scope="module")
47+
def df_duplicate_features_with_different_data_types():
48+
data = {
49+
"A": pd.Series([5.5] * 3).astype("float64"),
50+
"B": 1,
51+
"C": "foo",
52+
"D": pd.Timestamp("20010102"),
53+
"E": pd.Series([1.0] * 3).astype("float32"),
54+
"F": False,
55+
"G": pd.Series([1] * 3, dtype="int8"),
56+
}
57+
58+
df = pd.DataFrame(data)
59+
60+
return df
61+
62+
4663
def test_drop_duplicates_features(df_duplicate_features):
4764
transformer = DropDuplicateFeatures()
4865
X = transformer.fit_transform(df_duplicate_features)
@@ -94,3 +111,18 @@ def test_with_df_with_na(df_duplicate_features_with_na):
94111
{"City", "City2"},
95112
{"Age", "Age2"},
96113
]
114+
115+
116+
def test_with_different_data_types(df_duplicate_features_with_different_data_types):
117+
transformer = DropDuplicateFeatures()
118+
X = transformer.fit_transform(df_duplicate_features_with_different_data_types)
119+
df = pd.DataFrame(
120+
{
121+
"A": pd.Series([5.5] * 3).astype("float64"),
122+
"B": 1,
123+
"C": "foo",
124+
"D": pd.Timestamp("20010102"),
125+
"F": False,
126+
}
127+
)
128+
pd.testing.assert_frame_equal(X, df)

0 commit comments

Comments
 (0)