Skip to content

Commit 599e24f

Browse files
authored
FIX select sample from the targeted class in ClusterCentroids (scikit-learn-contrib#769)
1 parent 181cbc6 commit 599e24f

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

doc/whats_new/v0.7.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ Bug fixes
4949
:class:`imblearn.over_sampling.SMOTENC`.
5050
:pr:`675` by :user:`bganglia <bganglia>`.
5151

52+
- Fix a bug in :class:`imblearn.under_sampling.ClusterCentroids` where
53+
`voting="hard"` could have led to selecting a sample from any class instead of
54+
the targeted class.
55+
:pr:`769` by :user:`Guillaume Lemaitre <glemaitre>`.
56+
5257
Enhancements
5358
............
5459

imblearn/under_sampling/_prototype_generation/_cluster_centroids.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,17 +166,20 @@ def _fit_resample(self, X, y):
166166

167167
X_resampled, y_resampled = [], []
168168
for target_class in np.unique(y):
169+
target_class_indices = np.flatnonzero(y == target_class)
169170
if target_class in self.sampling_strategy_.keys():
170171
n_samples = self.sampling_strategy_[target_class]
171172
self.estimator_.set_params(**{"n_clusters": n_samples})
172-
self.estimator_.fit(X[y == target_class])
173+
self.estimator_.fit(_safe_indexing(X, target_class_indices))
173174
X_new, y_new = self._generate_sample(
174-
X, y, self.estimator_.cluster_centers_, target_class
175+
_safe_indexing(X, target_class_indices),
176+
_safe_indexing(y, target_class_indices),
177+
self.estimator_.cluster_centers_,
178+
target_class,
175179
)
176180
X_resampled.append(X_new)
177181
y_resampled.append(y_new)
178182
else:
179-
target_class_indices = np.flatnonzero(y == target_class)
180183
X_resampled.append(_safe_indexing(X, target_class_indices))
181184
y_resampled.append(_safe_indexing(y, target_class_indices))
182185

imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from scipy import sparse
77

88
from sklearn.cluster import KMeans
9+
from sklearn.datasets import make_classification
910

1011
from imblearn.under_sampling import ClusterCentroids
1112

@@ -121,3 +122,37 @@ def test_cluster_centroids_n_jobs():
121122
cc.fit_resample(X, Y)
122123
assert len(record) == 1
123124
assert "'n_jobs' was deprecated" in record[0].message.args[0]
125+
126+
127+
def test_cluster_centroids_hard_target_class():
128+
# check that the samples selected by the hard voting correspond to the
129+
# targeted class
130+
# non-regression test for:
131+
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/738
132+
X, y = make_classification(
133+
n_samples=1000,
134+
n_features=2,
135+
n_informative=1,
136+
n_redundant=0,
137+
n_repeated=0,
138+
n_clusters_per_class=1,
139+
weights=[0.3, 0.7],
140+
class_sep=0.01,
141+
random_state=0,
142+
)
143+
144+
cc = ClusterCentroids(voting="hard", random_state=0)
145+
X_res, y_res = cc.fit_resample(X, y)
146+
147+
minority_class_indices = np.flatnonzero(y == 0)
148+
X_minority_class = X[minority_class_indices]
149+
150+
resampled_majority_class_indices = np.flatnonzero(y_res == 1)
151+
X_res_majority = X_res[resampled_majority_class_indices]
152+
153+
sample_from_minority_in_majority = [
154+
np.all(np.isclose(selected_sample, minority_sample))
155+
for selected_sample in X_res_majority
156+
for minority_sample in X_minority_class
157+
]
158+
assert sum(sample_from_minority_in_majority) == 0

0 commit comments

Comments
 (0)