Skip to content

Commit c6f8d7b

Browse files
committed
refactor: max_loss and mean_max_loss multilabel strategies refactored
1 parent f2cf52e commit c6f8d7b

File tree

1 file changed: 8 additions and 32 deletions.

modAL/multilabel.py

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
4 4
from sklearn.multiclass import OneVsRestClassifier
5 5

6 6
from modAL.utils.data import modALinput
7 +
from modAL.utils.selection import multi_argmax
7 8
from typing import Tuple, Optional
8 9
from itertools import combinations
9 10

@@ -82,25 +83,13 @@ def max_loss(classifier: BaseEstimator,
82 83
The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
83 84
"""
84 85

85 -
most_certain_classes = classifier.predict_proba(X_pool).argmax(axis=1)
86 -
loss = _SVM_loss(classifier, X_pool, most_certain_classes=most_certain_classes)
87 -

88 86
assert len(X_pool) >= n_instances, 'n_instances cannot be larger than len(X_pool)'
89 87

90 -
if n_instances == 1:
91 -
query_idx = np.argmax(loss)
92 -
return query_idx, X_pool[query_idx]
93 -
else:
94 -
max_val = -np.inf
95 -
max_idx = None
96 -
for subset_idx in combinations(range(len(X_pool)), n_instances):
97 -
subset_sum = loss[list(subset_idx)].sum()
98 -
if subset_sum > max_val:
99 -
max_val = subset_sum
100 -
max_idx = subset_idx
88 +
most_certain_classes = classifier.predict_proba(X_pool).argmax(axis=1)
89 +
loss = _SVM_loss(classifier, X_pool, most_certain_classes=most_certain_classes)
101 90

102 -
query_idx = np.array(max_idx)
103 -
return query_idx, X_pool[query_idx]
91 +
query_idx = multi_argmax(loss, n_instances)
92 +
return query_idx, X_pool[query_idx]
104 93

105 94

106 95
def mean_max_loss(classifier: BaseEstimator,
@@ -123,21 +112,8 @@ def mean_max_loss(classifier: BaseEstimator,
123 112
The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
124 113
"""
125 114

126 -
loss = _SVM_loss(classifier, X_pool)
127 -

128 115
assert len(X_pool) >= n_instances, 'n_instances cannot be larger than len(X_pool)'
116 +
loss = _SVM_loss(classifier, X_pool)
129 117

130 -
if n_instances == 1:
131 -
query_idx = np.argmax(loss)
132 -
return query_idx, X_pool[query_idx]
133 -
else:
134 -
max_val = -np.inf
135 -
max_idx = None
136 -
for subset_idx in combinations(range(len(X_pool)), n_instances):
137 -
subset_sum = loss[list(subset_idx)].sum()
138 -
if subset_sum > max_val:
139 -
max_val = subset_sum
140 -
max_idx = subset_idx
141 -

142 -
query_idx = np.array(max_idx)
143 -
return query_idx, X_pool[query_idx]
118 +
query_idx = multi_argmax(loss, n_instances)
119 +
return query_idx, X_pool[query_idx]

0 commit comments

Comments (0)