Skip to content

Commit f2cf52e

Browse files
committed
add: Mean Max Loss and Max Loss multilabel SVM query strategies from Li et al. added
1 parent 21b29f1 commit f2cf52e

File tree

3 files changed

+106
-9
lines changed

3 files changed

+106
-9
lines changed

examples/multilabel_svm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
plt.scatter(X[y[:, 1] == 1, 0], X[y[:, 1] == 1, 1],
2525
facecolors='none', edgecolors='r', s=100, linewidths=2, label='class 2')
2626
plt.legend()
27-
#plt.show()
27+
plt.show()
2828

2929
learner = ActiveLearner(
3030
estimator=OneVsRestClassifier(SVC(probability=True)),

modAL/multilabel.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from modAL.utils.data import modALinput
77
from typing import Tuple, Optional
8+
from itertools import combinations
89

910

1011
def _SVM_loss(multiclass_classifier: OneVsRestClassifier,
@@ -30,7 +31,7 @@ def _SVM_loss(multiclass_classifier: OneVsRestClassifier,
3031
if most_certain_classes is None:
3132
cls_mtx = 2*np.eye(n_classes, n_classes) - 1
3233
loss_mtx = np.maximum(1-np.dot(predictions, cls_mtx), 0)
33-
return loss_mtx.mean(axis=0)
34+
return loss_mtx.mean(axis=1)
3435
else:
3536
cls_mtx = -np.ones(shape=(len(X), n_classes))
3637
for inst_idx, most_certain_class in enumerate(most_certain_classes):
@@ -63,10 +64,80 @@ def SVM_binary_minimum(classifier: BaseEstimator,
6364
def max_loss(classifier: BaseEstimator,
6465
X_pool: modALinput,
6566
n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
66-
pass
67+
68+
"""
69+
Max Loss query strategy for SVM multilabel classification.
70+
71+
For more details on this query strategy, see
72+
Li et al., Multilabel SVM active learning for image classification
73+
(http://dx.doi.org/10.1109/ICIP.2004.1421535)
74+
75+
Args:
76+
classifier: The multilabel classifier for which the labels are to be queried. Should be an SVM model
77+
such as the ones from sklearn.svm. Although the function will execute for other models as well,
78+
the mathematical calculations in Li et al. work only for SVM-s.
79+
X: The pool of samples to query from.
80+
81+
Returns:
82+
The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
83+
"""
84+
85+
most_certain_classes = classifier.predict_proba(X_pool).argmax(axis=1)
86+
loss = _SVM_loss(classifier, X_pool, most_certain_classes=most_certain_classes)
87+
88+
assert len(X_pool) >= n_instances, 'n_instances cannot be larger than len(X_pool)'
89+
90+
if n_instances == 1:
91+
query_idx = np.argmax(loss)
92+
return query_idx, X_pool[query_idx]
93+
else:
94+
max_val = -np.inf
95+
max_idx = None
96+
for subset_idx in combinations(range(len(X_pool)), n_instances):
97+
subset_sum = loss[list(subset_idx)].sum()
98+
if subset_sum > max_val:
99+
max_val = subset_sum
100+
max_idx = subset_idx
101+
102+
query_idx = np.array(max_idx)
103+
return query_idx, X_pool[query_idx]
67104

68105

69106
def mean_max_loss(classifier: BaseEstimator,
70107
X_pool: modALinput,
71108
n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
72-
pass
109+
"""
110+
Mean Max Loss query strategy for SVM multilabel classification.
111+
112+
For more details on this query strategy, see
113+
Li et al., Multilabel SVM active learning for image classification
114+
(http://dx.doi.org/10.1109/ICIP.2004.1421535)
115+
116+
Args:
117+
classifier: The multilabel classifier for which the labels are to be queried. Should be an SVM model
118+
such as the ones from sklearn.svm. Although the function will execute for other models as well,
119+
the mathematical calculations in Li et al. work only for SVM-s.
120+
X: The pool of samples to query from.
121+
122+
Returns:
123+
The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
124+
"""
125+
126+
loss = _SVM_loss(classifier, X_pool)
127+
128+
assert len(X_pool) >= n_instances, 'n_instances cannot be larger than len(X_pool)'
129+
130+
if n_instances == 1:
131+
query_idx = np.argmax(loss)
132+
return query_idx, X_pool[query_idx]
133+
else:
134+
max_val = -np.inf
135+
max_idx = None
136+
for subset_idx in combinations(range(len(X_pool)), n_instances):
137+
subset_sum = loss[list(subset_idx)].sum()
138+
if subset_sum > max_val:
139+
max_val = subset_sum
140+
max_idx = subset_idx
141+
142+
query_idx = np.array(max_idx)
143+
return query_idx, X_pool[query_idx]

tests/core_tests.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -928,17 +928,43 @@ def test_vote(self):
928928

929929
class TestMultilabel(unittest.TestCase):
930930
def test_SVM_loss(self):
931-
for n_classes in range(3, 10):
932-
for n_instances in range(5, 10):
931+
for n_classes in range(2, 10):
932+
for n_instances in range(1, 10):
933933
X_training = np.random.rand(n_instances, 5)
934934
y_training = np.random.randint(0, 2, size=(n_instances, n_classes))
935935
X_pool = np.random.rand(n_instances, 5)
936936
y_pool = np.random.randint(0, 2, size=(n_instances, n_classes))
937937
classifier = OneVsRestClassifier(SVC())
938938
classifier.fit(X_training, y_training)
939-
loss = modAL.multilabel._SVM_loss(classifier, X_pool)
940-
loss = modAL.multilabel._SVM_loss(classifier, X_pool,
941-
most_certain_classes=np.random.randint(0, n_classes, size=(n_instances)))
939+
avg_loss = modAL.multilabel._SVM_loss(classifier, X_pool)
940+
mcc_loss = modAL.multilabel._SVM_loss(classifier, X_pool,
941+
most_certain_classes=np.random.randint(0, n_classes, size=(n_instances)))
942+
self.assertEqual(avg_loss.shape, (len(X_pool), ))
943+
self.assertEqual(mcc_loss.shape, (len(X_pool),))
944+
945+
def test_mean_max_loss(self):
946+
for n_classes in range(2, 10):
947+
for n_pool_instances in range(1, 10):
948+
for n_query_instances in range(1, min(n_pool_instances, 3)):
949+
X_training = np.random.rand(n_pool_instances, 5)
950+
y_training = np.random.randint(0, 2, size=(n_pool_instances, n_classes))
951+
X_pool = np.random.rand(n_pool_instances, 5)
952+
y_pool = np.random.randint(0, 2, size=(n_pool_instances, n_classes))
953+
classifier = OneVsRestClassifier(SVC())
954+
classifier.fit(X_training, y_training)
955+
query_idx, query_inst = modAL.multilabel.mean_max_loss(classifier, X_pool, n_query_instances)
956+
957+
def test_max_loss(self):
958+
for n_classes in range(2, 10):
959+
for n_pool_instances in range(1, 10):
960+
for n_query_instances in range(1, min(n_pool_instances, 3)):
961+
X_training = np.random.rand(n_pool_instances, 5)
962+
y_training = np.random.randint(0, 2, size=(n_pool_instances, n_classes))
963+
X_pool = np.random.rand(n_pool_instances, 5)
964+
y_pool = np.random.randint(0, 2, size=(n_pool_instances, n_classes))
965+
classifier = OneVsRestClassifier(SVC(probability=True))
966+
classifier.fit(X_training, y_training)
967+
query_idx, query_inst = modAL.multilabel.max_loss(classifier, X_pool, n_query_instances)
942968

943969

944970
class TestExamples(unittest.TestCase):

0 commit comments

Comments
 (0)