Skip to content

Commit 37c68bb

Browse files
committed
add: _SVM_loss utility function for the MeanMaxLoss and MaxLoss multilabel strategies
1 parent 20000b8 commit 37c68bb

File tree

3 files changed

+73
-7
lines changed

3 files changed

+73
-7
lines changed

examples/multilabel_svm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from modAL.multilabel import SVM_binary_minimum
66

77
from sklearn.multiclass import OneVsRestClassifier
8-
from sklearn.svm import LinearSVC
8+
from sklearn.svm import SVC
99

1010
n_samples = 500
1111
X = np.random.normal(size=(n_samples, 2))
@@ -24,10 +24,10 @@
2424
plt.scatter(X[y[:, 1] == 1, 0], X[y[:, 1] == 1, 1],
2525
facecolors='none', edgecolors='r', s=100, linewidths=2, label='class 2')
2626
plt.legend()
27-
plt.show()
27+
#plt.show()
2828

2929
learner = ActiveLearner(
30-
estimator=OneVsRestClassifier(LinearSVC()),
30+
estimator=OneVsRestClassifier(SVC(probability=True)),
3131
query_strategy=SVM_binary_minimum,
3232
X_training=X_initial, y_training=y_initial
3333
)

modAL/multilabel.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,72 @@
11
import numpy as np
22

33
from sklearn.base import BaseEstimator
4+
from sklearn.multiclass import OneVsRestClassifier
45

56
from modAL.utils.data import modALinput
6-
from typing import Tuple
7+
from typing import Tuple, Optional
78

89

9-
def SVM_binary_minimum(classifier: BaseEstimator, X_pool: modALinput) -> Tuple[np.ndarray, modALinput]:
10+
def _SVM_loss(multiclass_classifier: OneVsRestClassifier,
11+
X: modALinput,
12+
most_certain_classes: Optional[int] = None) -> np.ndarray:
13+
"""
14+
Utility function for max_loss and mean_max_loss strategies.
15+
16+
Args:
17+
multiclass_classifier: sklearn.multiclass.OneVsRestClassifier instance for which the loss
18+
is to be calculated.
19+
X: The pool of samples to query from.
20+
most_certain_classes: optional, indexes of most certainly predicted class for each instance.
21+
If None, loss is calculated for all classes.
22+
23+
Returns:
24+
np.ndarray of shape (n_instances, ), losses for the instances in X.
25+
26+
"""
27+
predictions = 2*multiclass_classifier.predict(X)-1
28+
n_classes = len(multiclass_classifier.classes_)
29+
30+
if most_certain_classes is None:
31+
cls_mtx = 2*np.eye(n_classes, n_classes) - 1
32+
loss_mtx = np.maximum(1-np.dot(predictions, cls_mtx), 0)
33+
return loss_mtx.mean(axis=0)
34+
else:
35+
cls_mtx = -np.ones(shape=(len(X), n_classes))
36+
for inst_idx, most_certain_class in enumerate(most_certain_classes):
37+
cls_mtx[inst_idx, most_certain_class] = 1
38+
39+
cls_loss = np.maximum(1 - np.multiply(cls_mtx, predictions), 0).sum(axis=1)
40+
return cls_loss
41+
42+
43+
def SVM_binary_minimum(classifier: BaseEstimator,
                       X_pool: modALinput) -> Tuple[np.ndarray, modALinput]:
    """
    SVM binary minimum multilabel active learning strategy. For details see the paper
    Klaus Brinker, On Active Learning in Multi-label Classification
    (https://link.springer.com/chapter/10.1007%2F3-540-31314-1_24)

    Args:
        classifier: The multilabel classifier for which the labels are to be queried. Must be an SVM model
            such as the ones from sklearn.svm. The decision values are read from
            classifier.estimator.decision_function.
        X_pool: The pool of samples to query from.

    Returns:
        The index of the instance from X_pool chosen to be labelled; the instance from X_pool chosen to be labelled.
    """
    # For every instance, take the smallest absolute decision value along axis 1.
    # NOTE(review): assumes decision_function returns one column per label - confirm.
    min_abs_dist = np.min(np.abs(classifier.estimator.decision_function(X_pool)), axis=1)
    # The instance with the overall smallest value is the one to query.
    query_idx = np.argmin(min_abs_dist)
    return query_idx, X_pool[query_idx]
61+
62+
63+
def max_loss(classifier: BaseEstimator,
             X_pool: modALinput,
             n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
    """
    Placeholder for the max loss multilabel query strategy; the strategy is not implemented yet.

    Args:
        classifier: The multilabel classifier for which the labels are to be queried.
        X_pool: The pool of samples to query from.
        n_instances: The number of instances to be queried.

    Raises:
        NotImplementedError: Always, since the strategy has no implementation yet.
    """
    # Fail loudly instead of silently returning None, which would violate the declared return type.
    raise NotImplementedError('max_loss is not implemented yet')
67+
68+
69+
def mean_max_loss(classifier: BaseEstimator,
                  X_pool: modALinput,
                  n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
    """
    Placeholder for the mean max loss multilabel query strategy; the strategy is not implemented yet.

    Args:
        classifier: The multilabel classifier for which the labels are to be queried.
        X_pool: The pool of samples to query from.
        n_instances: The number of instances to be queried.

    Raises:
        NotImplementedError: Always, since the strategy has no implementation yet.
    """
    # Fail loudly instead of silently returning None, which would violate the declared return type.
    raise NotImplementedError('mean_max_loss is not implemented yet')

tests/core_tests.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111
import modAL.utils.selection
1212
import modAL.utils.validation
1313
import modAL.utils.combination
14+
import modAL.multilabel
1415

1516
from copy import deepcopy
1617
from itertools import chain, product
1718
from collections import namedtuple
19+
1820
from sklearn.ensemble import RandomForestClassifier
1921
from sklearn.exceptions import NotFittedError
2022
from sklearn.metrics import confusion_matrix
23+
from sklearn.svm import SVC
24+
from sklearn.multiclass import OneVsRestClassifier
2125
from scipy.stats import entropy, norm
2226
from scipy.special import ndtr
2327
from scipy import sparse as sp
@@ -922,6 +926,21 @@ def test_vote(self):
922926
)
923927

924928

929+
class TestMultilabel(unittest.TestCase):
    def test_SVM_loss(self):
        """Check that _SVM_loss yields non-negative, one-dimensional losses in both calling modes."""
        for n_classes in range(3, 10):
            for n_instances in range(5, 10):
                X_training = np.random.rand(n_instances, 5)
                y_training = np.random.randint(0, 2, size=(n_instances, n_classes))
                X_pool = np.random.rand(n_instances, 5)
                classifier = OneVsRestClassifier(SVC())
                classifier.fit(X_training, y_training)

                # loss calculated for all classes
                loss = modAL.multilabel._SVM_loss(classifier, X_pool)
                self.assertEqual(loss.ndim, 1)
                self.assertTrue(np.all(loss >= 0))

                # loss calculated against one given class index per instance
                most_certain_classes = np.random.randint(0, n_classes, size=n_instances)
                loss = modAL.multilabel._SVM_loss(classifier, X_pool,
                                                  most_certain_classes=most_certain_classes)
                self.assertEqual(loss.shape, (n_instances,))
                self.assertTrue(np.all(loss >= 0))
942+
943+
925944
class TestExamples(unittest.TestCase):
926945

927946
def test_examples(self):

0 commit comments

Comments
 (0)