55
66from modAL .models import ActiveLearner
77from modAL .utils .data import modALinput
8- from modAL .utils .selection import multi_argmax
8+ from modAL .utils .selection import multi_argmax , shuffled_argmax
99from typing import Tuple , Optional
1010from itertools import combinations
1111
1212
1313def _SVM_loss (multiclass_classifier : ActiveLearner ,
14- X : modALinput ,
15- most_certain_classes : Optional [int ] = None ) -> np .ndarray :
14+ X : modALinput , most_certain_classes : Optional [int ] = None ) -> np .ndarray :
1615 """
1716 Utility function for max_loss and mean_max_loss strategies.
1817
@@ -43,8 +42,8 @@ def _SVM_loss(multiclass_classifier: ActiveLearner,
4342 return cls_loss
4443
4544
46- def SVM_binary_minimum (classifier : ActiveLearner ,
47- X_pool : modALinput ) -> Tuple [np .ndarray , modALinput ]:
45+ def SVM_binary_minimum (classifier : ActiveLearner , X_pool : modALinput ,
46+ random_tie_break : bool = False ) -> Tuple [np .ndarray , modALinput ]:
4847 """
4948 SVM binary minimum multilabel active learning strategy. For details see the paper
5049 Klaus Brinker, On Active Learning in Multi-label Classification
@@ -53,23 +52,30 @@ def SVM_binary_minimum(classifier: ActiveLearner,
5352 Args:
5453 classifier: The multilabel classifier for which the labels are to be queried. Must be an SVM model
5554 such as the ones from sklearn.svm.
56- X: The pool of samples to query from.
55+ X_pool: The pool of samples to query from.
56+ random_tie_break: If True, shuffles utility scores to randomize the order. This
57+ can be used to break the tie when the highest utility score is not unique.
5758
5859 Returns:
59- The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
60+ The index of the instance from X_pool chosen to be labelled;
61+ the instance from X_pool chosen to be labelled.
6062 """
6163
6264 decision_function = np .array ([svm .decision_function (X_pool )
6365 for svm in classifier .estimator .estimators_ ]).T
6466
6567 min_abs_dist = np .min (np .abs (decision_function ), axis = 1 )
66- query_idx = np .argmin (min_abs_dist )
68+
69+ if not random_tie_break :
70+ query_idx = np .argmin (min_abs_dist )
71+ else :
72+ query_idx = shuffled_argmax (- min_abs_dist )
73+
6774 return query_idx , X_pool [query_idx ]
6875
6976
70- def max_loss (classifier : OneVsRestClassifier ,
71- X_pool : modALinput ,
72- n_instances : int = 1 ) -> Tuple [np .ndarray , modALinput ]:
77+ def max_loss (classifier : OneVsRestClassifier , X_pool : modALinput ,
78+ n_instances : int = 1 , random_tie_break : bool = False ) -> Tuple [np .ndarray , modALinput ]:
7379
7480 """
7581 Max Loss query strategy for SVM multilabel classification.
@@ -82,24 +88,30 @@ def max_loss(classifier: OneVsRestClassifier,
8288 classifier: The multilabel classifier for which the labels are to be queried. Should be an SVM model
8389 such as the ones from sklearn.svm. Although the function will execute for other models as well,
8490 the mathematical calculations in Li et al. work only for SVM-s.
85- X: The pool of samples to query from.
91+ X_pool: The pool of samples to query from.
92+ random_tie_break: If True, shuffles utility scores to randomize the order. This
93+ can be used to break the tie when the highest utility score is not unique.
8694
8795 Returns:
88- The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
96+ The index of the instance from X_pool chosen to be labelled;
97+ the instance from X_pool chosen to be labelled.
8998 """
9099
91100 assert len (X_pool ) >= n_instances , 'n_instances cannot be larger than len(X_pool)'
92101
93102 most_certain_classes = classifier .predict_proba (X_pool ).argmax (axis = 1 )
94103 loss = _SVM_loss (classifier , X_pool , most_certain_classes = most_certain_classes )
95104
96- query_idx = multi_argmax (loss , n_instances )
105+ if not random_tie_break :
106+ query_idx = multi_argmax (loss , n_instances )
107+ else :
108+ query_idx = shuffled_argmax (loss , n_instances )
109+
97110 return query_idx , X_pool [query_idx ]
98111
99112
100- def mean_max_loss (classifier : OneVsRestClassifier ,
101- X_pool : modALinput ,
102- n_instances : int = 1 ) -> Tuple [np .ndarray , modALinput ]:
113+ def mean_max_loss (classifier : OneVsRestClassifier , X_pool : modALinput ,
114+ n_instances : int = 1 , random_tie_break : bool = False ) -> Tuple [np .ndarray , modALinput ]:
103115 """
104116 Mean Max Loss query strategy for SVM multilabel classification.
105117
@@ -111,22 +123,28 @@ def mean_max_loss(classifier: OneVsRestClassifier,
111123 classifier: The multilabel classifier for which the labels are to be queried. Should be an SVM model
112124 such as the ones from sklearn.svm. Although the function will execute for other models as well,
113125 the mathematical calculations in Li et al. work only for SVM-s.
114- X: The pool of samples to query from.
126+ X_pool: The pool of samples to query from.
127+ random_tie_break: If True, shuffles utility scores to randomize the order. This
128+ can be used to break the tie when the highest utility score is not unique.
115129
116130 Returns:
117- The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
131+ The index of the instance from X_pool chosen to be labelled;
132+ the instance from X_pool chosen to be labelled.
118133 """
119134
120135 assert len (X_pool ) >= n_instances , 'n_instances cannot be larger than len(X_pool)'
121136 loss = _SVM_loss (classifier , X_pool )
122137
123- query_idx = multi_argmax (loss , n_instances )
138+ if not random_tie_break :
139+ query_idx = multi_argmax (loss , n_instances )
140+ else :
141+ query_idx = shuffled_argmax (loss , n_instances )
142+
124143 return query_idx , X_pool [query_idx ]
125144
126145
127- def min_confidence (classifier : OneVsRestClassifier ,
128- X_pool : modALinput ,
129- n_instances : int = 1 ) -> Tuple [np .ndarray , modALinput ]:
146+ def min_confidence (classifier : OneVsRestClassifier , X_pool : modALinput ,
147+ n_instances : int = 1 , random_tie_break : bool = False ) -> Tuple [np .ndarray , modALinput ]:
130148 """
131149 MinConfidence query strategy for multilabel classification.
132150
@@ -136,22 +154,28 @@ def min_confidence(classifier: OneVsRestClassifier,
136154
137155 Args:
138156 classifier: The multilabel classifier for which the labels are to be queried.
139- X: The pool of samples to query from.
157+ X_pool: The pool of samples to query from.
158+ random_tie_break: If True, shuffles utility scores to randomize the order. This
159+ can be used to break the tie when the highest utility score is not unique.
140160
141161 Returns:
142- The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
162+ The index of the instance from X_pool chosen to be labelled;
163+ the instance from X_pool chosen to be labelled.
143164 """
144165
145166 classwise_confidence = classifier .predict_proba (X_pool )
146167 classwise_min = np .min (classwise_confidence , axis = 1 )
147- query_idx = multi_argmax ((- 1 )* classwise_min , n_instances )
168+
169+ if not random_tie_break :
170+ query_idx = multi_argmax (- classwise_min , n_instances )
171+ else :
172+ query_idx = shuffled_argmax (- classwise_min , n_instances )
148173
149174 return query_idx , X_pool [query_idx ]
150175
151176
152- def avg_confidence (classifier : OneVsRestClassifier ,
153- X_pool : modALinput ,
154- n_instances : int = 1 ) -> Tuple [np .ndarray , modALinput ]:
177+ def avg_confidence (classifier : OneVsRestClassifier , X_pool : modALinput ,
178+ n_instances : int = 1 , random_tie_break : bool = False ) -> Tuple [np .ndarray , modALinput ]:
155179 """
156180 AvgConfidence query strategy for multilabel classification.
157181
@@ -161,22 +185,28 @@ def avg_confidence(classifier: OneVsRestClassifier,
161185
162186 Args:
163187 classifier: The multilabel classifier for which the labels are to be queried.
164- X: The pool of samples to query from.
188+ X_pool: The pool of samples to query from.
189+ random_tie_break: If True, shuffles utility scores to randomize the order. This
190+ can be used to break the tie when the highest utility score is not unique.
165191
166192 Returns:
167- The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
193+ The index of the instance from X_pool chosen to be labelled;
194+ the instance from X_pool chosen to be labelled.
168195 """
169196
170197 classwise_confidence = classifier .predict_proba (X_pool )
171198 classwise_mean = np .mean (classwise_confidence , axis = 1 )
172- query_idx = multi_argmax (classwise_mean , n_instances )
199+
200+ if not random_tie_break :
201+ query_idx = multi_argmax (classwise_mean , n_instances )
202+ else :
203+ query_idx = shuffled_argmax (classwise_mean , n_instances )
173204
174205 return query_idx , X_pool [query_idx ]
175206
176207
177- def max_score (classifier : OneVsRestClassifier ,
178- X_pool : modALinput ,
179- n_instances : int = 1 ) -> Tuple [np .ndarray , modALinput ]:
208+ def max_score (classifier : OneVsRestClassifier , X_pool : modALinput ,
209+ n_instances : int = 1 , random_tie_break : bool = False ) -> Tuple [np .ndarray , modALinput ]:
180210 """
181211 MaxScore query strategy for multilabel classification.
182212
@@ -186,24 +216,30 @@ def max_score(classifier: OneVsRestClassifier,
186216
187217 Args:
188218 classifier: The multilabel classifier for which the labels are to be queried.
189- X: The pool of samples to query from.
219+ X_pool: The pool of samples to query from.
220+ random_tie_break: If True, shuffles utility scores to randomize the order. This
221+ can be used to break the tie when the highest utility score is not unique.
190222
191223 Returns:
192- The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
224+ The index of the instance from X_pool chosen to be labelled;
225+ the instance from X_pool chosen to be labelled.
193226 """
194227
195228 classwise_confidence = classifier .predict_proba (X_pool )
196229 classwise_predictions = classifier .predict (X_pool )
197230 classwise_scores = classwise_confidence * (classwise_predictions - 1 / 2 )
198231 classwise_max = np .max (classwise_scores , axis = 1 )
199- query_idx = multi_argmax (classwise_max , n_instances )
232+
233+ if not random_tie_break :
234+ query_idx = multi_argmax (classwise_max , n_instances )
235+ else :
236+ query_idx = shuffled_argmax (classwise_max , n_instances )
200237
201238 return query_idx , X_pool [query_idx ]
202239
203240
204- def avg_score (classifier : OneVsRestClassifier ,
205- X_pool : modALinput ,
206- n_instances : int = 1 ) -> Tuple [np .ndarray , modALinput ]:
241+ def avg_score (classifier : OneVsRestClassifier , X_pool : modALinput ,
242+ n_instances : int = 1 , random_tie_break : bool = False ) -> Tuple [np .ndarray , modALinput ]:
207243 """
208244 AvgScore query strategy for multilabel classification.
209245
@@ -213,16 +249,23 @@ def avg_score(classifier: OneVsRestClassifier,
213249
214250 Args:
215251 classifier: The multilabel classifier for which the labels are to be queried.
216- X: The pool of samples to query from.
252+ X_pool: The pool of samples to query from.
253+ random_tie_break: If True, shuffles utility scores to randomize the order. This
254+ can be used to break the tie when the highest utility score is not unique.
217255
218256 Returns:
219- The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
257+ The index of the instance from X_pool chosen to be labelled;
258+ the instance from X_pool chosen to be labelled.
220259 """
221260
222261 classwise_confidence = classifier .predict_proba (X_pool )
223262 classwise_predictions = classifier .predict (X_pool )
224263 classwise_scores = classwise_confidence * (classwise_predictions - 1 / 2 )
225264 classwise_mean = np .mean (classwise_scores , axis = 1 )
226- query_idx = multi_argmax (classwise_mean , n_instances )
265+
266+ if not random_tie_break :
267+ query_idx = multi_argmax (classwise_mean , n_instances )
268+ else :
269+ query_idx = shuffled_argmax (classwise_mean , n_instances )
227270
228271 return query_idx , X_pool [query_idx ]
0 commit comments