Skip to content

Commit a951029

Browse files
authored
DOC/EXA solve Tomek examples (scikit-learn-contrib#263)
1 parent 8d9d21a commit a951029

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

examples/applications/plot_over_sampling_benchmark_lfw.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def fit_sample(self, X, y):
5454
y[y == majority_person] = 0
5555
y[y == minority_person] = 1
5656

57-
5857
classifier = ['3NN', neighbors.KNeighborsClassifier(3)]
5958

6059
samplers = [

examples/under-sampling/plot_tomek_links.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,48 +4,56 @@
44
===========
55
66
An illustration of the Tomek links method.
7-
87
"""
98

109
import numpy as np
1110
import matplotlib.pyplot as plt
1211

13-
from sklearn.datasets import make_blobs
12+
from sklearn.model_selection import train_test_split
13+
from sklearn.utils import shuffle
1414

1515
from imblearn.under_sampling import TomekLinks
1616

1717
print(__doc__)
1818

19-
20-
# create a synthetic dataset
21-
X, y = make_blobs(n_samples=500, centers=2, n_features=2,
22-
random_state=0, center_box=(-5.0, 5.0))
19+
rng = np.random.RandomState(0)
20+
n_samples_1 = 500
21+
n_samples_2 = 50
22+
X_syn = np.r_[1.5 * rng.randn(n_samples_1, 2),
23+
0.5 * rng.randn(n_samples_2, 2) + [2, 2]]
24+
y_syn = np.array([0] * (n_samples_1) + [1] * (n_samples_2))
25+
X_syn, y_syn = shuffle(X_syn, y_syn)
26+
X_syn_train, X_syn_test, y_syn_train, y_syn_test = train_test_split(X_syn,
27+
y_syn)
2328

2429
# remove Tomek links
2530
tl = TomekLinks(return_indices=True)
26-
X_resampled, y_resampled, idx_resampled = tl.fit_sample(X, y)
31+
X_resampled, y_resampled, idx_resampled = tl.fit_sample(X_syn, y_syn)
2732

2833
fig = plt.figure()
2934
ax = fig.add_subplot(1, 1, 1)
3035

31-
idx_class_0 = np.flatnonzero(y_resampled == 0)
32-
idx_class_1 = np.flatnonzero(y_resampled == 1)
33-
idx_samples_removed = np.setdiff1d(np.flatnonzero(y == 1),
34-
np.union1d(idx_class_0, idx_class_1))
35-
36-
plt.scatter(X[idx_class_0, 0], X[idx_class_0, 1],
37-
c='g', alpha=.8, label='Class #0')
38-
plt.scatter(X[idx_class_1, 0], X[idx_class_1, 1],
39-
c='b', alpha=.8, label='Class #1')
40-
plt.scatter(X[idx_samples_removed, 0], X[idx_samples_removed, 1],
41-
c='r', alpha=.8, label='Removed samples')
42-
36+
idx_samples_removed = np.setdiff1d(np.arange(X_syn.shape[0]),
37+
idx_resampled)
38+
idx_class_0 = y_resampled == 0
39+
plt.scatter(X_resampled[idx_class_0, 0], X_resampled[idx_class_0, 1],
40+
alpha=.8, label='Class #0')
41+
plt.scatter(X_resampled[~idx_class_0, 0], X_resampled[~idx_class_0, 1],
42+
alpha=.8, label='Class #1')
43+
plt.scatter(X_syn[idx_samples_removed, 0], X_syn[idx_samples_removed, 1],
44+
alpha=.8, label='Removed samples')
45+
46+
# make nice plotting
4347
ax.spines['top'].set_visible(False)
4448
ax.spines['right'].set_visible(False)
4549
ax.get_xaxis().tick_bottom()
4650
ax.get_yaxis().tick_left()
4751
ax.spines['left'].set_position(('outward', 10))
4852
ax.spines['bottom'].set_position(('outward', 10))
53+
ax.set_xlim([-5, 5])
54+
ax.set_ylim([-5, 5])
55+
plt.yticks(range(-5, 6))
56+
plt.xticks(range(-5, 6))
4957

5058
plt.title('Under-sampling removing Tomek links')
5159
plt.legend()

0 commit comments

Comments
 (0)