Skip to content

Commit 20000b8

Browse files
committed
add: multilabel SVM binary minimum strategy
1 parent b9c833f commit 20000b8

File tree

4 files changed

+91
-0
lines changed

4 files changed

+91
-0
lines changed

examples/multilabel_svm.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
4+
from modAL.models import ActiveLearner
5+
from modAL.multilabel import SVM_binary_minimum
6+
7+
from sklearn.multiclass import OneVsRestClassifier
8+
from sklearn.svm import LinearSVC
9+
10+
# Synthetic two-label problem: label k is switched on when feature k is positive.
n_samples = 500
X = np.random.normal(size=(n_samples, 2))
y = (X > 0).astype(int)

# Draw the initially labelled samples; everything else forms the query pool.
n_initial = 10
initial_idx = np.random.choice(range(len(X)), size=n_initial, replace=False)
X_initial = X[initial_idx]
y_initial = y[initial_idx]
X_pool = np.delete(X, initial_idx, axis=0)
y_pool = np.delete(y, initial_idx, axis=0)

# Plot every point in black, then ring the points that carry each label.
with plt.style.context('seaborn-white'):
    plt.figure(figsize=(10, 10))
    plt.scatter(X[:, 0], X[:, 1], c='k', s=20)
    for label_col, ring_color, ring_size, legend_name in (
        (0, 'b', 50, 'class 1'),
        (1, 'r', 100, 'class 2'),
    ):
        has_label = y[:, label_col] == 1
        plt.scatter(X[has_label, 0], X[has_label, 1],
                    facecolors='none', edgecolors=ring_color, s=ring_size,
                    linewidths=2, label=legend_name)
    plt.legend()
    plt.show()

# One-vs-rest linear SVMs queried with the binary minimum strategy.
learner = ActiveLearner(
    estimator=OneVsRestClassifier(LinearSVC()),
    query_strategy=SVM_binary_minimum,
    X_training=X_initial,
    y_training=y_initial,
)

learner.query(X_pool)

modAL/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .models import ActiveLearner, Committee, CommitteeRegressor
2+
3+
# Names imported from .models that make up the package-level public API.
__all__ = ['ActiveLearner', 'Committee', 'CommitteeRegressor']

modAL/multilabel.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
3+
from sklearn.base import BaseEstimator
4+
5+
from modAL.utils.data import modALinput
6+
from typing import Tuple
7+
8+
9+
def SVM_binary_minimum(classifier: BaseEstimator, X_pool: modALinput) -> Tuple[np.ndarray, modALinput]:
    """
    SVM binary minimum multilabel active learning strategy. For details see the paper
    Klaus Brinker, On Active Learning in Multi-label Classification
    (https://link.springer.com/chapter/10.1007%2F3-540-31314-1_24)

    The sample whose smallest absolute decision value over all labels is the
    lowest (i.e. the one lying closest to any of the binary decision
    boundaries) is selected.

    Args:
        classifier: The object for which the labels are to be queried. Its ``estimator`` attribute
            must provide ``decision_function``, e.g. ``OneVsRestClassifier(LinearSVC())``.
        X_pool: The pool of samples to query from.

    Returns:
        The index of the instance from X_pool chosen to be labelled; the instance from X_pool
        chosen to be labelled.
    """
    decision_values = np.asarray(classifier.estimator.decision_function(X_pool))

    # With a single binary classifier the scores come back one-dimensional;
    # treat them as one label column so the per-sample reduction is defined.
    if decision_values.ndim == 1:
        decision_values = decision_values.reshape(-1, 1)

    min_abs_dist = np.min(np.abs(decision_values), axis=1)
    query_idx = np.argmin(min_abs_dist)
    return query_idx, X_pool[query_idx]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import numpy as np
2+
3+
from modAL.models import ActiveLearner
4+
from modAL.multilabel import SVM_binary_minimum
5+
6+
from sklearn.multiclass import OneVsRestClassifier
7+
from sklearn.svm import LinearSVC
8+
9+
# Synthetic two-label problem: label k is switched on when feature k is positive.
n_samples = 500
X = np.random.normal(size=(n_samples, 2))
y = (X > 0).astype(int)

# Draw the initially labelled samples; everything else forms the query pool.
n_initial = 10
initial_idx = np.random.choice(range(len(X)), size=n_initial, replace=False)
X_initial = X[initial_idx]
y_initial = y[initial_idx]
X_pool = np.delete(X, initial_idx, axis=0)
y_pool = np.delete(y, initial_idx, axis=0)

# One-vs-rest linear SVMs queried with the binary minimum strategy.
learner = ActiveLearner(
    estimator=OneVsRestClassifier(LinearSVC()),
    query_strategy=SVM_binary_minimum,
    X_training=X_initial,
    y_training=y_initial,
)

# Active learning loop: query one sample, teach it, then drop it from the pool.
n_queries = 10
for _ in range(n_queries):
    query_idx, _query_inst = learner.query(X_pool)
    new_X = X_pool[query_idx].reshape(1, -1)
    new_y = y_pool[query_idx].reshape(1, -1)
    learner.teach(new_X, new_y)
    X_pool = np.delete(X_pool, query_idx, axis=0)
    y_pool = np.delete(y_pool, query_idx, axis=0)

0 commit comments

Comments
 (0)