|
4 | 4 | ===========
|
5 | 5 |
|
6 | 6 | An illustration of the Tomek links method.
|
7 |
| -
|
8 | 7 | """
|
9 | 8 |
|
10 | 9 | import numpy as np
|
11 | 10 | import matplotlib.pyplot as plt
|
12 | 11 |
|
13 |
| -from sklearn.datasets import make_blobs |
| 12 | +from sklearn.model_selection import train_test_split |
| 13 | +from sklearn.utils import shuffle |
14 | 14 |
|
15 | 15 | from imblearn.under_sampling import TomekLinks
|
16 | 16 |
|
17 | 17 | print(__doc__)
|
18 | 18 |
|
19 |
| - |
20 |
| -# create a synthetic dataset |
21 |
| -X, y = make_blobs(n_samples=500, centers=2, n_features=2, |
22 |
| - random_state=0, center_box=(-5.0, 5.0)) |
| 19 | +rng = np.random.RandomState(0) |
| 20 | +n_samples_1 = 500 |
| 21 | +n_samples_2 = 50 |
| 22 | +X_syn = np.r_[1.5 * rng.randn(n_samples_1, 2), |
| 23 | + 0.5 * rng.randn(n_samples_2, 2) + [2, 2]] |
| 24 | +y_syn = np.array([0] * (n_samples_1) + [1] * (n_samples_2)) |
| 25 | +X_syn, y_syn = shuffle(X_syn, y_syn) |
| 26 | +X_syn_train, X_syn_test, y_syn_train, y_syn_test = train_test_split(X_syn, |
| 27 | + y_syn) |
23 | 28 |
|
24 | 29 | # remove Tomek links
|
25 | 30 | tl = TomekLinks(return_indices=True)
|
26 |
| -X_resampled, y_resampled, idx_resampled = tl.fit_sample(X, y) |
| 31 | +X_resampled, y_resampled, idx_resampled = tl.fit_sample(X_syn, y_syn) |
27 | 32 |
|
28 | 33 | fig = plt.figure()
|
29 | 34 | ax = fig.add_subplot(1, 1, 1)
|
30 | 35 |
|
31 |
| -idx_class_0 = np.flatnonzero(y_resampled == 0) |
32 |
| -idx_class_1 = np.flatnonzero(y_resampled == 1) |
33 |
| -idx_samples_removed = np.setdiff1d(np.flatnonzero(y == 1), |
34 |
| - np.union1d(idx_class_0, idx_class_1)) |
35 |
| - |
36 |
| -plt.scatter(X[idx_class_0, 0], X[idx_class_0, 1], |
37 |
| - c='g', alpha=.8, label='Class #0') |
38 |
| -plt.scatter(X[idx_class_1, 0], X[idx_class_1, 1], |
39 |
| - c='b', alpha=.8, label='Class #1') |
40 |
| -plt.scatter(X[idx_samples_removed, 0], X[idx_samples_removed, 1], |
41 |
| - c='r', alpha=.8, label='Removed samples') |
42 |
| - |
| 36 | +idx_samples_removed = np.setdiff1d(np.arange(X_syn.shape[0]), |
| 37 | + idx_resampled) |
| 38 | +idx_class_0 = y_resampled == 0 |
| 39 | +plt.scatter(X_resampled[idx_class_0, 0], X_resampled[idx_class_0, 1], |
| 40 | + alpha=.8, label='Class #0') |
| 41 | +plt.scatter(X_resampled[~idx_class_0, 0], X_resampled[~idx_class_0, 1], |
| 42 | + alpha=.8, label='Class #1') |
| 43 | +plt.scatter(X_syn[idx_samples_removed, 0], X_syn[idx_samples_removed, 1], |
| 44 | + alpha=.8, label='Removed samples') |
| 45 | + |
| 46 | +# make nice plotting |
43 | 47 | ax.spines['top'].set_visible(False)
|
44 | 48 | ax.spines['right'].set_visible(False)
|
45 | 49 | ax.get_xaxis().tick_bottom()
|
46 | 50 | ax.get_yaxis().tick_left()
|
47 | 51 | ax.spines['left'].set_position(('outward', 10))
|
48 | 52 | ax.spines['bottom'].set_position(('outward', 10))
|
| 53 | +ax.set_xlim([-5, 5]) |
| 54 | +ax.set_ylim([-5, 5]) |
| 55 | +plt.yticks(range(-5, 6)) |
| 56 | +plt.xticks(range(-5, 6)) |
49 | 57 |
|
50 | 58 | plt.title('Under-sampling removing Tomek links')
|
51 | 59 | plt.legend()
|
|
0 commit comments