Skip to content

Commit 60d5546

Browse files
committed
add test for multilabel quire
1 parent 65791a4 commit 60d5546

File tree

3 files changed

+107
-32
lines changed

3 files changed

+107
-32
lines changed

libact/query_strategies/multilabel/multilabel_quire.py

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,10 @@ def __init__(self, dataset, lamba=1.0, kernel='rbf', gamma=1., coef0=1.,
8181
X, _ = zip(*dataset.get_entries())
8282
self.kernel = kernel
8383
if self.kernel == 'rbf':
84-
self.K = rbf_kernel(X=X, Y=X, gamma=kwargs.pop('gamma', 1.))
84+
self.K = rbf_kernel(X=X, Y=X, gamma=gamma)
8585
elif self.kernel == 'poly':
86-
self.K = polynomial_kernel(X=X,
87-
Y=X,
88-
coef0=kwargs.pop('coef0', 1),
89-
degree=kwargs.pop('degree', 3),
90-
gamma=kwargs.pop('gamma', 1.))
86+
self.K = polynomial_kernel(X=X, Y=X, coef0=coef0, degree=degree,
87+
gamma=gamma)
9188
elif self.kernel == 'linear':
9289
self.K = linear_kernel(X=X, Y=X)
9390
elif hasattr(self.kernel, '__call__'):
@@ -99,48 +96,70 @@ def __init__(self, dataset, lamba=1.0, kernel='rbf', gamma=1., coef0=1.,
9996

10097

10198
_, lbled_Y = zip(*dataset.get_labeled_entries())
99+
self.n_labels = np.shape(lbled_Y)[1]
102100
n = len(X)
103-
m = np.shape(lbled_Y)[1]
101+
m = self.n_labels
104102
# label correlation matrix
105103
R = np.corrcoef(np.array(lbled_Y).T)
106104
R = np.nan_to_num(R)
107105
self.RK = np.kron(R, self.K)
108106

109107
self.L = lamba * (np.linalg.pinv(self.RK + lamba * np.eye(n*m)))
110108

111-
@inherit_docstring_from(QueryStrategy)
112-
def make_query(self):
113-
dataset = self.dataset
114-
X, Y = zip(*dataset.get_entries())
115-
_, lbled_Y = zip(*dataset.get_labeled_entries())
116-
117-
X = np.array(X)
118-
RK = self.RK
119-
n_instance = len(X)
120-
m = np.shape(lbled_Y)[1]
121-
lamba = self.lamba
122-
109+
def _get_index(self):
110+
_, Y = zip(*self.dataset.get_entries())
111+
n_instance = len(Y)
112+
m = self.n_labels
123113
# index for labeled and unlabeled instance
124114
l_id = []
125115
a_id = []
126116
for i in range(n_instance * m):
127-
if Y[i%n_instance] is None:
117+
if Y[i // m] is None:
128118
a_id.append(i)
129119
else:
130120
l_id.append(i)
121+
return a_id, l_id
122+
123+
#def update(self, entry_id, label):
124+
# # calculate invLaa
125+
# invLaa = self.invLaa
126+
# # idx before update
127+
# a_id, l_id = self.idxs
128+
# m = len(label)
129+
# # assert len(np.where(np.array(a_id) == entry_id*m)[0]) == 1
130+
# idx = np.where(np.array(a_id) == entry_id*m)[0][0]
131+
# for i in range(m):
132+
# D = np.delete(np.delete(invLaa, idx, axis=0), idx, axis=1)
133+
# b = np.delete(invLaa, idx, axis=0)[:, idx]
134+
# # invLuu
135+
# invLaa = D - 1./invLaa[idx, idx] * np.dot(b, b.T)
136+
# self.invLaa = invLaa
137+
138+
@inherit_docstring_from(QueryStrategy)
139+
def make_query(self):
140+
dataset = self.dataset
141+
X, Y = zip(*dataset.get_entries())
142+
X = np.array(X)
131143

144+
n_instance = len(X)
145+
m = self.n_labels
146+
RK = self.RK
147+
lamba = self.lamba
132148
L = self.L
133-
vecY = np.reshape(np.array([y for y in Y if y is not None]).T, (-1, 1))
134-
detLaa = np.linalg.det(L[np.ix_(a_id, a_id)])
135-
#invLaa = np.linalg.pinv(L[np.ix_(a_id, a_id)])
136-
invLaa = (lamba * np.eye(len(a_id)) + RK[np.ix_(a_id, a_id)]) \
149+
150+
a_id, l_id = self._get_index()
151+
# invLaa = np.linalg.pinv(L[np.ix_(a_id, a_id)])
152+
invLaa = ((lamba * np.eye(len(a_id)) + RK[np.ix_(a_id, a_id)]) \
137153
- np.dot(np.dot(RK[np.ix_(a_id, l_id)],
138154
np.linalg.pinv(lamba * np.eye(len(l_id)) \
139155
+ RK[np.ix_(l_id, l_id)])),
140-
RK[np.ix_(l_id, a_id)])
156+
RK[np.ix_(l_id, a_id)])) / lamba
157+
158+
vecY = np.reshape(np.array([y for y in Y if y is not None]).T, (-1, 1))
159+
detLaa = np.linalg.det(L[np.ix_(a_id, a_id)])
141160

161+
score = np.zeros(len(a_id))
142162
b = np.zeros((len(a_id)-1))
143-
score = []
144163
D = np.zeros((len(a_id)-1, len(a_id)-1))
145164
D[...] = invLaa[1:, 1:]
146165
for i, s in enumerate(a_id):
@@ -162,13 +181,13 @@ def make_query(self):
162181
b[i:] = invLaa[i+1:, i]
163182
invLuu = D - 1./invLaa[i, i] * np.dot(b, b.T)
164183

165-
score.append(L[s, s] - detLaa / L[s, s] \
166-
+ 2 * np.abs(np.dot(L[s, l_id] \
167-
- np.dot(np.dot(L[s, u_id], invLuu),
168-
L[np.ix_(u_id, l_id)]), vecY)))
184+
score[i] = L[s, s] - detLaa / L[s, s] \
185+
+ 2 * np.abs(np.dot(L[s, l_id] \
186+
- np.dot(np.dot(L[s, u_id], invLuu),
187+
L[np.ix_(u_id, l_id)]), vecY))
169188

170-
score = np.sum(np.array(score).reshape(m, -1).T, axis=1)
189+
score = np.sum(score.reshape(m, -1).T, axis=1)
171190

172191
ask_idx = self.random_state_.choice(np.where(score == np.min(score))[0])
173192

174-
return a_id[ask_idx]
193+
return a_id[ask_idx] // m
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import unittest
2+
3+
from numpy.testing import assert_array_equal
4+
import numpy as np
5+
6+
from libact.base.dataset import Dataset
7+
from libact.query_strategies.multilabel import MultilabelQUIRE
8+
from libact.utils import run_qs
9+
10+
11+
class MultilabelQUIRETestCase(unittest.TestCase):
12+
"""Multilabel QUIRE test case using artificial dataset"""
13+
def setUp(self):
14+
self.X = [[-2, -1], [1, 1], [-1, -2], [-1, -1], [1, 2], [2, 1]]
15+
self.y = [[0, 1], [1, 0], [0, 1], [1, 0], [1, 0], [1, 1]]
16+
self.quota = 4
17+
18+
def test_multilabel_quire(self):
19+
trn_ds = Dataset(self.X, (self.y[:2] + [None] * (len(self.y) - 2)))
20+
qs = MultilabelQUIRE(trn_ds)
21+
qseq = run_qs(trn_ds, qs, self.y, self.quota)
22+
assert_array_equal(qseq, np.array([2, 3, 4, 5]))
23+
24+
25+
if __name__ == '__main__':
26+
unittest.main()

libact/utils/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,33 @@ def calc_cost(y, yhat, cost_matrix):
5050
ith class and prediction as jth class.
5151
"""
5252
return np.mean(cost_matrix[list(y), list(yhat)])
53+
54+
def run_qs(trn_ds, qs, truth, quota):
55+
"""Run query strategy on specified dataset and return the querying sequence.
56+
57+
Parameters
58+
----------
59+
trn_ds : Dataset object
60+
The dataset to be run on.
61+
62+
qs : QueryStrategy instance
63+
The active learning algorithm to be run.
64+
65+
truth : array-like
66+
The true labels, indexed by entry_id.
67+
68+
quota : int
69+
Number of iterations to run
70+
71+
Returns
72+
-------
73+
qseq : numpy array, shape (quota,)
74+
The numpy array of entry_id representing querying sequence.
75+
"""
76+
ret = []
77+
for _ in range(quota):
78+
ask_id = qs.make_query()
79+
trn_ds.update(ask_id, truth[ask_id])
80+
81+
ret.append(ask_id)
82+
return np.array(ret)

0 commit comments

Comments
 (0)