 Expected error reduction framework for active learning.
 """
 
-from typing import Tuple
+from typing import Tuple, Callable
 
 import numpy as np
 
 from modAL.models import ActiveLearner
 from modAL.utils.data import modALinput, data_vstack
 from modAL.utils.selection import multi_argmax
+from modAL.uncertainty import _proba_uncertainty, _proba_entropy
 
 
-def expected_error_reduction(learner: ActiveLearner, X: modALinput,
+def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str = 'binary',
                              p_subsample: np.float = 1.0, n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
     """
     Expected error reduction query strategy.
@@ -25,18 +26,23 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput,
         Roy and McCallum, 2001 (http://groups.csail.mit.edu/rrg/papers/icml01.pdf)
 
     Args:
-        learner: The ActiveLearner object for which the expected error is to be estimated.
+        learner: The ActiveLearner object for which the expected error
+            is to be estimated.
         X: The samples.
-        p_subsample: Probability of keeping a sample from the pool when calculating expected error.
-            Significantly improves runtime for large sample pools.
+        loss: The loss function to be used. Can be 'binary' or 'log'.
+        p_subsample: Probability of keeping a sample from the pool when
+            calculating expected error. Significantly improves runtime
+            for large sample pools.
         n_instances: The number of instances to be sampled.
 
 
     Returns:
-        The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
+        The indices of the instances from X chosen to be labelled;
+        the instances from X chosen to be labelled.
     """
 
     assert 0.0 <= p_subsample <= 1.0, 'p_subsample subsampling keep ratio must be between 0.0 and 1.0'
+    assert loss in ['binary', 'log'], 'loss must be \'binary\' or \'log\''
 
     expected_error = np.zeros(shape=(len(X), ))
     possible_labels = np.unique(learner.y_training)
@@ -56,66 +62,17 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput,
                 y_new = data_vstack((learner.y_training, np.array(y).reshape(1, )))
 
                 refitted_estimator = clone(learner.estimator).fit(X_new, y_new)
-                uncertainty = 1 - np.max(refitted_estimator.predict_proba(X), axis=1)
+                refitted_proba = refitted_estimator.predict_proba(X)
+                # 'binary': 1 - max class probability; 'log': entropy of the predictive distribution
+                if loss == 'binary':
+                    nloss = _proba_uncertainty(refitted_proba)
+                elif loss == 'log':
+                    nloss = _proba_entropy(refitted_proba)
 
-                expected_error[x_idx] += np.sum(uncertainty)*X_proba[x_idx, y_idx]
+                expected_error[x_idx] += np.sum(nloss)*X_proba[x_idx, y_idx]
 
         else:
             expected_error[x_idx] = np.inf
 
     query_idx = multi_argmax(expected_error, n_instances)
 
     return query_idx, X[query_idx]
-
-
-def expected_log_loss_reduction(learner: ActiveLearner, X: modALinput,
-                                p_subsample: np.float = 1.0, n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
-    """
-    Expected log loss reduction query strategy.
-
-    References:
-        Roy and McCallum, 2001 (http://groups.csail.mit.edu/rrg/papers/icml01.pdf)
-
-    Args:
-        learner: The ActiveLearner object for which the expected log loss is to be estimated.
-        X: The samples.
-        p_subsample: Probability of keeping a sample from the pool when calculating expected log loss.
-            Significantly improves runtime for large sample pools.
-        n_instances: The number of instances to be sampled.
-
-
-    Returns:
-        The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
-    """
-
-    assert 0.0 <= p_subsample <= 1.0, 'p_subsample subsampling keep ratio must be between 0.0 and 1.0'
-
-    expected_log_loss = np.zeros(shape=(len(X), ))
-    possible_labels = np.unique(learner.y_training)
-
-    try:
-        X_proba = learner.predict_proba(X)
-    except NotFittedError:
-        # TODO: implement a proper cold-start
-        return 0, X[0]
-
-    for x_idx, x in enumerate(X):
-        # subsample the data if needed
-        if np.random.rand() <= p_subsample:
-            # estimate the expected error
-            for y_idx, y in enumerate(possible_labels):
-                X_new = data_vstack((learner.X_training, x.reshape(1, -1)))
-                y_new = data_vstack((learner.y_training, np.array(y).reshape(1, )))
-
-                refitted_estimator = clone(learner.estimator).fit(X_new, y_new)
-                refitted_proba = refitted_estimator.predict_proba(X)
-                entr = np.transpose(entropy(np.transpose(refitted_proba)))
-
-                expected_log_loss[x_idx] += np.sum(entr)*X_proba[x_idx, y_idx]
-
-        else:
-            expected_log_loss[x_idx] = np.inf
-
-    query_idx = multi_argmax(expected_log_loss, n_instances)
-
-    return query_idx, X[query_idx]
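
A minimal usage sketch (not part of the commit) of how the new loss parameter can be selected through functools.partial; the estimator, the toy data shapes and the modAL.expected_error module path are illustrative assumptions rather than anything fixed by this diff:

from functools import partial

import numpy as np
from sklearn.ensemble import RandomForestClassifier

from modAL.models import ActiveLearner
from modAL.expected_error import expected_error_reduction  # assumed module path for this file

# small illustrative dataset: 2-D points with binary labels
X_training = np.random.rand(10, 2)
y_training = np.array([0]*5 + [1]*5)
X_pool = np.random.rand(50, 2)

learner = ActiveLearner(
    estimator=RandomForestClassifier(),
    # loss='log' plays the role of the removed expected_log_loss_reduction strategy;
    # p_subsample keeps the per-candidate, per-label refits tractable on larger pools
    query_strategy=partial(expected_error_reduction, loss='log', p_subsample=0.5),
    X_training=X_training, y_training=y_training
)

query_idx, query_instances = learner.query(X_pool)

Since the strategy refits a clone of the estimator once for every candidate sample and every possible label, p_subsample is the main lever for keeping the query step affordable on large pools.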