11import numpy as np
22
33from sklearn .base import BaseEstimator
4+ from sklearn .multiclass import OneVsRestClassifier
45
56from modAL .utils .data import modALinput
6- from typing import Tuple
7+ from typing import Tuple , Optional
78
89
9- def SVM_binary_minimum (classifier : BaseEstimator , X_pool : modALinput ) -> Tuple [np .ndarray , modALinput ]:
10+ def _SVM_loss (multiclass_classifier : OneVsRestClassifier ,
11+ X : modALinput ,
12+ most_certain_classes : Optional [int ] = None ) -> np .ndarray :
13+ """
14+ Utility function for max_loss and mean_max_loss strategies.
15+
16+ Args:
17+ multiclass_classifier: sklearn.multiclass.OneVsRestClassifier instance for which the loss
18+ is to be calculated.
19+ X: The pool of samples to query from.
20+ most_certain_classes: optional, indexes of most certainly predicted class for each instance.
21+ If None, loss is calculated for all classes.
22+
23+ Returns:
24+ np.ndarray of shape (n_instances, ), losses for the instances in X.
25+
26+ """
27+ predictions = 2 * multiclass_classifier .predict (X )- 1
28+ n_classes = len (multiclass_classifier .classes_ )
29+
30+ if most_certain_classes is None :
31+ cls_mtx = 2 * np .eye (n_classes , n_classes ) - 1
32+ loss_mtx = np .maximum (1 - np .dot (predictions , cls_mtx ), 0 )
33+ return loss_mtx .mean (axis = 0 )
34+ else :
35+ cls_mtx = - np .ones (shape = (len (X ), n_classes ))
36+ for inst_idx , most_certain_class in enumerate (most_certain_classes ):
37+ cls_mtx [inst_idx , most_certain_class ] = 1
38+
39+ cls_loss = np .maximum (1 - np .multiply (cls_mtx , predictions ), 0 ).sum (axis = 1 )
40+ return cls_loss
41+
42+
43+ def SVM_binary_minimum (classifier : BaseEstimator ,
44+ X_pool : modALinput ) -> Tuple [np .ndarray , modALinput ]:
1045 """
1146 SVM binary minimum multilabel active learning strategy. For details see the paper
1247 Klaus Brinker, On Active Learning in Multi-label Classification
1348 (https://link.springer.com/chapter/10.1007%2F3-540-31314-1_24)
1449
1550 Args:
1651 classifier: The multilabel classifier for which the labels are to be queried. Must be an SVM model
17- such as the ones from sklearn.svm.
52+ such as the ones from sklearn.svm.
1853 X: The pool of samples to query from.
1954
2055 Returns:
2156 The index of the instance from X chosen to be labelled; the instance from X chosen to be labelled.
2257 """
2358 min_abs_dist = np .min (np .abs (classifier .estimator .decision_function (X_pool )), axis = 1 )
2459 query_idx = np .argmin (min_abs_dist )
25- return query_idx , X_pool [query_idx ]
60+ return query_idx , X_pool [query_idx ]
61+
62+
63+ def max_loss (classifier : BaseEstimator ,
64+ X_pool : modALinput ,
65+ n_instances : int = 1 ) -> Tuple [np .ndarray , modALinput ]:
66+ pass
67+
68+
69+ def mean_max_loss (classifier : BaseEstimator ,
70+ X_pool : modALinput ,
71+ n_instances : int = 1 ) -> Tuple [np .ndarray , modALinput ]:
72+ pass
0 commit comments