44from sklearn .multiclass import OneVsRestClassifier
55
66from modAL .utils .data import modALinput
7+ from modAL .utils .selection import multi_argmax
78from typing import Tuple , Optional
89from itertools import combinations
910
@@ -82,25 +83,13 @@ def max_loss(classifier: BaseEstimator,
8283 The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
8384 """
8485
85- most_certain_classes = classifier .predict_proba (X_pool ).argmax (axis = 1 )
86- loss = _SVM_loss (classifier , X_pool , most_certain_classes = most_certain_classes )
87-
8886 assert len (X_pool ) >= n_instances , 'n_instances cannot be larger than len(X_pool)'
8987
90- if n_instances == 1 :
91- query_idx = np .argmax (loss )
92- return query_idx , X_pool [query_idx ]
93- else :
94- max_val = - np .inf
95- max_idx = None
96- for subset_idx in combinations (range (len (X_pool )), n_instances ):
97- subset_sum = loss [list (subset_idx )].sum ()
98- if subset_sum > max_val :
99- max_val = subset_sum
100- max_idx = subset_idx
88+ most_certain_classes = classifier .predict_proba (X_pool ).argmax (axis = 1 )
89+ loss = _SVM_loss (classifier , X_pool , most_certain_classes = most_certain_classes )
10190
102- query_idx = np . array ( max_idx )
103- return query_idx , X_pool [query_idx ]
91+ query_idx = multi_argmax ( loss , n_instances )
92+ return query_idx , X_pool [query_idx ]
10493
10594
10695def mean_max_loss (classifier : BaseEstimator ,
@@ -123,21 +112,8 @@ def mean_max_loss(classifier: BaseEstimator,
123112 The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
124113 """
125114
126- loss = _SVM_loss (classifier , X_pool )
127-
128115 assert len (X_pool ) >= n_instances , 'n_instances cannot be larger than len(X_pool)'
116+ loss = _SVM_loss (classifier , X_pool )
129117
130- if n_instances == 1 :
131- query_idx = np .argmax (loss )
132- return query_idx , X_pool [query_idx ]
133- else :
134- max_val = - np .inf
135- max_idx = None
136- for subset_idx in combinations (range (len (X_pool )), n_instances ):
137- subset_sum = loss [list (subset_idx )].sum ()
138- if subset_sum > max_val :
139- max_val = subset_sum
140- max_idx = subset_idx
141-
142- query_idx = np .array (max_idx )
143- return query_idx , X_pool [query_idx ]
118+ query_idx = multi_argmax (loss , n_instances )
119+ return query_idx , X_pool [query_idx ]
0 commit comments