55
66from modAL .utils .data import modALinput
77from typing import Tuple , Optional
8+ from itertools import combinations
89
910
1011def _SVM_loss (multiclass_classifier : OneVsRestClassifier ,
@@ -30,7 +31,7 @@ def _SVM_loss(multiclass_classifier: OneVsRestClassifier,
3031 if most_certain_classes is None :
3132 cls_mtx = 2 * np .eye (n_classes , n_classes ) - 1
3233 loss_mtx = np .maximum (1 - np .dot (predictions , cls_mtx ), 0 )
33- return loss_mtx .mean (axis = 0 )
34+ return loss_mtx .mean (axis = 1 )
3435 else :
3536 cls_mtx = - np .ones (shape = (len (X ), n_classes ))
3637 for inst_idx , most_certain_class in enumerate (most_certain_classes ):
@@ -63,10 +64,80 @@ def SVM_binary_minimum(classifier: BaseEstimator,
def max_loss(classifier: BaseEstimator,
             X_pool: modALinput,
             n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
    """
    Max Loss query strategy for SVM multilabel classification.

    For more details on this query strategy, see
    Li et al., Multilabel SVM active learning for image classification
    (http://dx.doi.org/10.1109/ICIP.2004.1421535)

    Args:
        classifier: The multilabel classifier for which the labels are to be queried. Should be an SVM model
            such as the ones from sklearn.svm. Although the function will execute for other models as well,
            the mathematical calculations in Li et al. work only for SVM-s. It must provide ``predict_proba``.
        X_pool: The pool of samples to query from.
        n_instances: The number of instances to select from the pool.

    Returns:
        The index (or, for ``n_instances != 1``, the ascending array of indices) of the instances from
        X_pool chosen to be labelled; the corresponding instances from X_pool.

    Raises:
        AssertionError: If ``n_instances`` is larger than ``len(X_pool)``.
        ValueError: If ``n_instances`` is negative.
    """
    # Validate before doing any model evaluation. The check is raised explicitly
    # (instead of via `assert`) so it is not stripped under `python -O`; the
    # exception type and message are kept for backward compatibility.
    if len(X_pool) < n_instances:
        raise AssertionError('n_instances cannot be larger than len(X_pool)')
    if n_instances < 0:
        raise ValueError('n_instances must be non-negative')

    # For every sample, the class with the highest predicted probability is
    # treated as its most certain class when computing the loss.
    most_certain_classes = classifier.predict_proba(X_pool).argmax(axis=1)
    loss = _SVM_loss(classifier, X_pool, most_certain_classes=most_certain_classes)

    if n_instances == 1:
        query_idx = np.argmax(loss)
        return query_idx, X_pool[query_idx]

    # The subset with the largest total loss is simply the n_instances samples
    # with the largest individual losses, so no enumeration of subsets is needed.
    # A stable sort breaks ties in favour of the lowest indices.
    # NOTE(review): assumes `loss` is a 1-D array with one entry per sample - confirm against _SVM_loss.
    top_idx = np.argsort(-loss, kind='stable')[:n_instances]
    query_idx = np.sort(top_idx)
    return query_idx, X_pool[query_idx]
67104
68105
def mean_max_loss(classifier: BaseEstimator,
                  X_pool: modALinput,
                  n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
    """
    Mean Max Loss query strategy for SVM multilabel classification.

    For more details on this query strategy, see
    Li et al., Multilabel SVM active learning for image classification
    (http://dx.doi.org/10.1109/ICIP.2004.1421535)

    Args:
        classifier: The multilabel classifier for which the labels are to be queried. Should be an SVM model
            such as the ones from sklearn.svm. Although the function will execute for other models as well,
            the mathematical calculations in Li et al. work only for SVM-s.
        X_pool: The pool of samples to query from.
        n_instances: The number of instances to select from the pool.

    Returns:
        The index (or, for ``n_instances != 1``, the ascending array of indices) of the instances from
        X_pool chosen to be labelled; the corresponding instances from X_pool.

    Raises:
        AssertionError: If ``n_instances`` is larger than ``len(X_pool)``.
        ValueError: If ``n_instances`` is negative.
    """
    # Validate before doing any model evaluation. The check is raised explicitly
    # (instead of via `assert`) so it is not stripped under `python -O`; the
    # exception type and message are kept for backward compatibility.
    if len(X_pool) < n_instances:
        raise AssertionError('n_instances cannot be larger than len(X_pool)')
    if n_instances < 0:
        raise ValueError('n_instances must be non-negative')

    # Without most_certain_classes, _SVM_loss averages the loss over the classes
    # for each sample.
    loss = _SVM_loss(classifier, X_pool)

    if n_instances == 1:
        query_idx = np.argmax(loss)
        return query_idx, X_pool[query_idx]

    # The subset with the largest total loss is simply the n_instances samples
    # with the largest individual losses, so no enumeration of subsets is needed.
    # A stable sort breaks ties in favour of the lowest indices.
    # NOTE(review): assumes `loss` is a 1-D array with one entry per sample - confirm against _SVM_loss.
    top_idx = np.argsort(-loss, kind='stable')[:n_instances]
    query_idx = np.sort(top_idx)
    return query_idx, X_pool[query_idx]
0 commit comments