Skip to content

Commit 4ed22a2

Browse files
Added tests for instance selection algorithms #174
1 parent 4275f82 commit 4ed22a2

File tree

8 files changed

+158
-23
lines changed

8 files changed

+158
-23
lines changed

instance_selection/_DROP3.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,6 @@ def filter(self, samples, y):
138138
if np.array_equal(neigh, x_sample):
139139
break
140140
a_neighs = a_neighs[:index_a] + a_neighs[index_a + 1:]
141-
try:
142-
assert len(a_neighs) == self.nearest_neighbors
143-
except AssertionError:
144-
breakpoint()
145141
# Find a new neigh for the associate
146142
remaining_samples = [x for x, _, _ in initial_distances]
147143
knn = NearestNeighbors(
@@ -162,20 +158,12 @@ def filter(self, samples, y):
162158
a_neighs.append(pos_neigh)
163159
break
164160

165-
try:
166-
assert len(a_neighs) == self.nearest_neighbors + 1
167-
except AssertionError:
168-
print('Duplicated instances')
169-
170161
samples_info[tuple(a_associate_of_x)][0] = a_neighs
171162

172163
# Add a_associate to the associates list of the new neigh
173164
new_neigh = a_neighs[-1]
174-
try:
175-
samples_info[tuple(new_neigh)][1].append(
176-
a_associate_of_x)
177-
except TypeError:
178-
pass
165+
samples_info[tuple(new_neigh)][1].append(
166+
a_associate_of_x)
179167

180168
samples = pd.DataFrame([x for x, _, _ in initial_distances],
181169
columns=self.x_attr)

instance_selection/_ENN.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, nearest_neighbors=3, power_parameter=2):
1818
self.power_parameter = power_parameter
1919
self.x_attr = None
2020

21-
def neighs(self, s_samples, s_targets, index, removed):
21+
def __neighs(self, s_samples, s_targets, index, removed):
2222
x_sample = s_samples[index - removed]
2323
x_target = s_targets[index - removed]
2424
knn = NearestNeighbors(n_jobs=-1,
@@ -56,8 +56,8 @@ def filter(self, samples, y):
5656
removed = 0
5757

5858
for index in range(size):
59-
_, x_target, targets_not_x, samples_not_x, neigh_ind = self.neighs(
60-
s_samples, s_targets, index, removed)
59+
_, x_target, targets_not_x, samples_not_x, neigh_ind = \
60+
self.__neighs(s_samples, s_targets, index, removed)
6161
y_targets = np.ravel(
6262
np.array([targets_not_x[x] for x in neigh_ind[0]])).astype(int)
6363
count = np.bincount(y_targets)
@@ -100,9 +100,9 @@ def filter_original_complete(self, original, original_y, complete,
100100

101101
for index in range(size):
102102
x_sample, x_target, targets_not_x, samples_not_x, neigh_ind = \
103-
self.neighs(s_samples, s_targets, index, removed)
103+
self.__neighs(s_samples, s_targets, index, removed)
104104
y_targets = [targets_not_x[x] for x in neigh_ind[0]]
105-
count = np.bincount(y_targets)
105+
count = np.bincount(np.ravel(y_targets))
106106
max_class = np.where(count == np.amax(count))[0][0]
107107
if max_class != x_target:
108108
delete = True

instance_selection/_LocalSets.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# @Version: 2.0
77
import sys
88

9+
import numpy as np
910
import pandas as pd
1011
from sklearn.metrics import pairwise_distances
1112

@@ -65,6 +66,12 @@ def usefulness(self, e):
6566
def get_local_sets(self):
6667
return self.local_sets
6768

69+
@staticmethod
def check_frame_to_numpy(y):
    """Flatten a pandas DataFrame into a one-dimensional NumPy array.

    :param y: object to normalise.
    :return: ``np.ravel(y.to_numpy())`` when ``y`` is a ``pd.DataFrame``;
        otherwise ``y`` itself, returned untouched.
    """
    is_frame = isinstance(y, pd.DataFrame)
    return np.ravel(y.to_numpy()) if is_frame else y
74+
6875

6976
class LSSm(LocalSets):
7077
def __init__(self):
@@ -73,8 +80,8 @@ def __init__(self):
7380
def filter(self, instances, labels):
7481
names = instances.keys()
7582
instances = instances.to_numpy()
76-
import numpy as np
7783
instances = [np.ravel(i) for i in instances]
84+
labels = self.check_frame_to_numpy(labels)
7885
if len(instances) != len(labels):
7986
raise ValueError(
8087
f'The dimension of the labeled data must be the same as the '
@@ -113,6 +120,7 @@ def filter(self, instances, labels):
113120
f'number of labels given. {len(instances)} != {len(labels)}'
114121
)
115122
self.n_id = len(instances)
123+
labels = self.check_frame_to_numpy(labels)
116124
lssm = LSSm()
117125
instances, labels = lssm.filter(instances, labels)
118126
instances = instances.to_numpy()

instance_selection/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._transformer import transform, transform_original_complete,\
1+
from ._transformer import transform, transform_original_complete, \
22
delete_multiple_element
33

44
__all__ = [

is-ssl.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ name: IS-SSL
22
channels:
33
- conda-forge
44
- default
5+
- anaconda
56
dependencies:
67
- numpy=1.20.3
78
- scikit-learn=0.24.2
89
- matplotlib=3.4.3
910
- pandas=1.3.4
1011
- yagmail=0.15.277
11-
- scipy~=1.7.1
12+
- scipy=1.7.1
13+
- pytest=7.1.1

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ scikit-learn~=0.24.2
33
matplotlib~=3.4.3
44
pandas~=1.3.4
55
yagmail~=0.15.277
6-
scipy~=1.7.1
6+
scipy~=1.7.1
7+
pytest~=7.1.1

tests/InstanceSelection.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#!/usr/bin/env python
2+
# -*- coding:utf-8 -*-
3+
# @Filename: InstanceSelection.py
4+
# @Author: Daniel Puente Ramírez
5+
# @Time: 15/4/22 16:20
6+
7+
import random
8+
9+
import numpy as np
10+
import pandas as pd
11+
import pytest
12+
from sklearn.datasets import load_iris
13+
14+
from instance_selection import ENN, CNN, RNN, ICF, MSS, DROP3, LSSm, LSBo
15+
16+
17+
def to_dataframe(y):
    """Return ``y`` as a pandas DataFrame.

    A value that is already a ``pd.DataFrame`` is handed back as-is (same
    object); anything else is passed to the ``pd.DataFrame`` constructor.
    """
    return y if isinstance(y, pd.DataFrame) else pd.DataFrame(y)
21+
22+
23+
@pytest.fixture
def iris_dataset():
    """Provide the iris features and their labels for the tests.

    The labels returned by ``load_iris`` are passed through ``to_dataframe``
    before being handed to the test.

    :return: tuple ``(features, labels)``.
    """
    features, target = load_iris(return_X_y=True, as_frame=True)
    return features, to_dataframe(target)
28+
29+
30+
@pytest.fixture
def iris_dataset_ss():
    """Provide a semi-supervised split of the iris dataset.

    30% of the row positions are drawn as "unlabeled"; the remaining rows
    make up the labeled subset. The complete dataset is returned alongside
    that subset.

    :return: tuple ``(original, original_labels, complete, complete_labels)``
        where ``original``/``original_labels`` are the labeled rows and
        ``complete``/``complete_labels`` are the whole dataset.
    """
    complete, complete_labels = load_iris(return_X_y=True, as_frame=True)
    complete_labels = to_dataframe(complete_labels)
    n_samples = complete.shape[0]

    # Private, seeded generator: the split is identical on every run (so a
    # failure can be reproduced) and the global ``random`` state is untouched.
    rng = random.Random(0)
    # A set gives O(1) membership tests when building the labeled positions.
    unlabeled = set(rng.sample(range(n_samples), int(n_samples * 0.3)))
    labeled = [pos for pos in range(n_samples) if pos not in unlabeled]

    original = complete.loc[labeled]
    original_labels = complete_labels.loc[labeled]

    return original, original_labels, complete, complete_labels
46+
47+
48+
def base(x, y, algorithm, params=None):
    """Run one instance selection algorithm and check its output shape.

    :param x: samples; must be a ``pd.DataFrame``.
    :param y: labels; must be a ``pd.DataFrame``.
    :param algorithm: class (or callable) that builds an object exposing
        ``filter(x, y)``.
    :param params: optional dict of keyword arguments for ``algorithm``;
        ``None`` builds it with no arguments.
    :raises AssertionError: if the inputs are not DataFrames, the number of
        columns changes, samples and labels end up with different row
        counts, or no sample was removed.
    """
    assert isinstance(x, pd.DataFrame)
    assert isinstance(y, pd.DataFrame)

    kwargs = {} if params is None else params
    model = algorithm(**kwargs)
    x_filtered, y_filtered = model.filter(x, y)

    # Filtering removes rows only: the column counts must be preserved.
    assert x_filtered.shape[1] == x.shape[1]
    assert y_filtered.shape[1] == y.shape[1]

    # Samples and labels stay aligned, and at least one sample was dropped.
    kept = x_filtered.shape[0]
    assert kept == y_filtered.shape[0]
    assert kept < x.shape[0]
58+
59+
60+
def test_enn_original(iris_dataset):
    """Run ENN (3 neighbours, power parameter 2) through ``base`` on iris."""
    features, labels = iris_dataset
    enn_params = {'nearest_neighbors': 3, 'power_parameter': 2}
    base(features, labels, algorithm=ENN, params=enn_params)
63+
64+
65+
def test_cnn(iris_dataset):
    """Run CNN with its default arguments through ``base`` on iris."""
    features, labels = iris_dataset
    base(features, labels, algorithm=CNN)
68+
69+
70+
def test_rnn(iris_dataset):
    """Run RNN with its default arguments through ``base`` on iris."""
    features, labels = iris_dataset
    base(features, labels, algorithm=RNN)
73+
74+
75+
def test_icf(iris_dataset):
    """Run ICF (3 neighbours, power parameter 2) through ``base`` on iris."""
    features, labels = iris_dataset
    icf_params = {'nearest_neighbors': 3, 'power_parameter': 2}
    base(features, labels, algorithm=ICF, params=icf_params)
78+
79+
80+
def test_mss(iris_dataset):
    """Run MSS with its default arguments through ``base`` on iris."""
    features, labels = iris_dataset
    base(features, labels, algorithm=MSS)
83+
84+
85+
def test_drop3(iris_dataset):
    """Run DROP3 (3 neighbours, power parameter 2) through ``base`` on iris."""
    features, labels = iris_dataset
    drop3_params = {'nearest_neighbors': 3, 'power_parameter': 2}
    base(features, labels, algorithm=DROP3, params=drop3_params)
88+
89+
90+
def test_local_sets_lssm(iris_dataset):
    """Run LSSm with its default arguments through ``base`` on iris."""
    features, labels = iris_dataset
    base(features, labels, algorithm=LSSm)
93+
94+
95+
def test_local_sets_lsbo(iris_dataset):
    """Run LSBo with its default arguments through ``base`` on iris."""
    features, labels = iris_dataset
    base(features, labels, algorithm=LSBo)
98+
99+
100+
def test_enn_ss(iris_dataset_ss):
    """Check ``ENN.filter_original_complete`` on the semi-supervised split.

    Every labeled ("original") sample is looked up in the filtered output by
    value; the labels found at those positions must equal the original
    labels, and the output must keep the column count and not exceed the row
    count of the complete dataset.
    """
    original, original_labels, complete, complete_labels = iris_dataset_ss

    x, y = ENN().filter_original_complete(
        original, original_labels, complete, complete_labels)

    # Position, inside the filtered samples, of the first row equal to each
    # original sample. A sample with no match contributes nothing, which
    # makes the label comparison below fail on length.
    new_orig = []
    for ori in original.to_numpy():
        found = next(
            (position for position, candidate in enumerate(x.to_numpy())
             if np.array_equal(ori, candidate)),
            None)
        if found is not None:
            new_orig.append(found)

    # NOTE(review): ``.loc`` is label-based, so this presumably relies on
    # ``y`` having a default 0..n-1 index -- confirm against ENN's output.
    selected_labels = np.ravel(y.loc[new_orig].to_numpy())
    expected_labels = np.ravel(original_labels.to_numpy())
    assert np.array_equal(expected_labels, selected_labels)
    assert complete.shape[1] == x.shape[1]
    assert complete.shape[0] >= x.shape[0]
119+
120+
121+
def test_different_len(iris_dataset):
    """LSSm and LSBo must raise ``ValueError`` when x and y lengths differ."""
    x, y = iris_dataset
    # Positional slice: drop only the last label so ``y`` is exactly one row
    # shorter than ``x``. ``.loc[:-1]`` would be label-based and, on the
    # default integer index, select no rows at all.
    y = y.iloc[:-1]
    for algorithm in (LSSm, LSBo):
        with pytest.raises(ValueError):
            algorithm().filter(x, y)

tests/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/usr/bin/env python
2+
# -*- coding:utf-8 -*-
3+
# @Filename: __init__.py
4+
# @Author: Daniel Puente Ramírez
5+
# @Time: 15/4/22 16:19
6+
7+
"""Python module for testing"""

0 commit comments

Comments
 (0)