
Commit 1a19391

refactor: expected error reduction and log loss reduction merged
1 parent 47fd12b · commit 1a19391

File tree

2 files changed: 22 additions, 67 deletions


modAL/expected_error.py

Lines changed: 18 additions & 61 deletions
@@ -2,7 +2,7 @@
 Expected error reduction framework for active learning.
 """

-from typing import Tuple
+from typing import Tuple, Callable

 import numpy as np

@@ -14,9 +14,10 @@
 from modAL.models import ActiveLearner
 from modAL.utils.data import modALinput, data_vstack
 from modAL.utils.selection import multi_argmax
+from modAL.uncertainty import _proba_uncertainty, _proba_entropy


-def expected_error_reduction(learner: ActiveLearner, X: modALinput,
+def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str = 'binary',
                              p_subsample: np.float = 1.0, n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
     """
     Expected error reduction query strategy.
@@ -25,18 +26,23 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput,
         Roy and McCallum, 2001 (http://groups.csail.mit.edu/rrg/papers/icml01.pdf)

     Args:
-        learner: The ActiveLearner object for which the expected error is to be estimated.
+        learner: The ActiveLearner object for which the expected error
+            is to be estimated.
         X: The samples.
-        p_subsample: Probability of keeping a sample from the pool when calculating expected error.
-            Significantly improves runtime for large sample pools.
+        loss: The loss function to be used. Can be 'binary' or 'log'.
+        p_subsample: Probability of keeping a sample from the pool when
+            calculating expected error. Significantly improves runtime
+            for large sample pools.
         n_instances: The number of instances to be sampled.


     Returns:
-        The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
+        The indices of the instances from X chosen to be labelled;
+        the instances from X chosen to be labelled.
     """

     assert 0.0 <= p_subsample <= 1.0, 'p_subsample subsampling keep ratio must be between 0.0 and 1.0'
+    assert loss in ['binary', 'log'], 'loss must be \'binary\' or \'log\''

     expected_error = np.zeros(shape=(len(X), ))
     possible_labels = np.unique(learner.y_training)
@@ -56,66 +62,17 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput,
                 y_new = data_vstack((learner.y_training, np.array(y).reshape(1, )))

                 refitted_estimator = clone(learner.estimator).fit(X_new, y_new)
-                uncertainty = 1 - np.max(refitted_estimator.predict_proba(X), axis=1)
+                refitted_proba = refitted_estimator.predict_proba(X)
+                if loss == 'binary':
+                    nloss = _proba_uncertainty(refitted_proba)
+                elif loss == 'log':
+                    nloss = _proba_entropy(refitted_proba)

-                expected_error[x_idx] += np.sum(uncertainty)*X_proba[x_idx, y_idx]
+                expected_error[x_idx] += np.sum(nloss)*X_proba[x_idx, y_idx]

         else:
             expected_error[x_idx] = np.inf

     query_idx = multi_argmax(expected_error, n_instances)

     return query_idx, X[query_idx]
-
-
-def expected_log_loss_reduction(learner: ActiveLearner, X: modALinput,
-                                p_subsample: np.float = 1.0, n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
-    """
-    Expected log loss reduction query strategy.
-
-    References:
-        Roy and McCallum, 2001 (http://groups.csail.mit.edu/rrg/papers/icml01.pdf)
-
-    Args:
-        learner: The ActiveLearner object for which the expected log loss is to be estimated.
-        X: The samples.
-        p_subsample: Probability of keeping a sample from the pool when calculating expected log loss.
-            Significantly improves runtime for large sample pools.
-        n_instances: The number of instances to be sampled.
-
-
-    Returns:
-        The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
-    """
-
-    assert 0.0 <= p_subsample <= 1.0, 'p_subsample subsampling keep ratio must be between 0.0 and 1.0'
-
-    expected_log_loss = np.zeros(shape=(len(X), ))
-    possible_labels = np.unique(learner.y_training)
-
-    try:
-        X_proba = learner.predict_proba(X)
-    except NotFittedError:
-        # TODO: implement a proper cold-start
-        return 0, X[0]
-
-    for x_idx, x in enumerate(X):
-        # subsample the data if needed
-        if np.random.rand() <= p_subsample:
-            # estimate the expected error
-            for y_idx, y in enumerate(possible_labels):
-                X_new = data_vstack((learner.X_training, x.reshape(1, -1)))
-                y_new = data_vstack((learner.y_training, np.array(y).reshape(1, )))
-
-                refitted_estimator = clone(learner.estimator).fit(X_new, y_new)
-                refitted_proba = refitted_estimator.predict_proba(X)
-                entr = np.transpose(entropy(np.transpose(refitted_proba)))
-
-                expected_log_loss[x_idx] += np.sum(entr)*X_proba[x_idx, y_idx]
-
-        else:
-            expected_log_loss[x_idx] = np.inf
-
-    query_idx = multi_argmax(expected_log_loss, n_instances)
-
-    return query_idx, X[query_idx]
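
The merge folds the deleted expected_log_loss_reduction into expected_error_reduction via the loss argument: 'binary' keeps the old inline uncertainty computation, 'log' keeps the old entropy computation. A minimal sketch of the two losses, assuming the modAL helpers _proba_uncertainty and _proba_entropy compute exactly what the inline code removed above did:

import numpy as np
from scipy.stats import entropy

def binary_loss(proba: np.ndarray) -> np.ndarray:
    # 'binary': one minus the maximum class probability per sample,
    # as in the removed `uncertainty = 1 - np.max(..., axis=1)` line
    return 1 - np.max(proba, axis=1)

def log_loss(proba: np.ndarray) -> np.ndarray:
    # 'log': per-sample entropy of the predicted class distribution,
    # as in the removed `entr = np.transpose(entropy(np.transpose(...)))` line
    return np.transpose(entropy(np.transpose(proba)))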

tests/core_tests.py

Lines changed: 4 additions & 6 deletions
@@ -393,16 +393,14 @@ def test_eer(self):
         learner = modAL.models.ActiveLearner(RandomForestClassifier(n_estimators=2),
                                              X_training=X_training, y_training=y_training)

-        modAL.expected_error.expected_log_loss_reduction(learner, X_pool)
         modAL.expected_error.expected_error_reduction(learner, X_pool)
-        modAL.expected_error.expected_log_loss_reduction(learner, X_pool, p_subsample=0.1)
         modAL.expected_error.expected_error_reduction(learner, X_pool, p_subsample=0.1)
-        modAL.expected_error.expected_log_loss_reduction(learner, X_pool)
-        modAL.expected_error.expected_error_reduction(learner, X_pool)
+        modAL.expected_error.expected_error_reduction(learner, X_pool, loss='binary')
+        modAL.expected_error.expected_error_reduction(learner, X_pool, p_subsample=0.1, loss='log')
         self.assertRaises(AssertionError, modAL.expected_error.expected_error_reduction,
                           learner, X_pool, p_subsample=1.5)
-        self.assertRaises(AssertionError, modAL.expected_error.expected_log_loss_reduction,
-                          learner, X_pool, p_subsample=1.5)
+        self.assertRaises(AssertionError, modAL.expected_error.expected_error_reduction,
+                          learner, X_pool, loss=42)


 class TestUncertainties(unittest.TestCase):
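
The updated test exercises both losses of the merged API. A usage sketch along the same lines (the toy data here is hypothetical, with shapes chosen for illustration):

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from modAL.models import ActiveLearner
from modAL.expected_error import expected_error_reduction

# hypothetical toy pool and seed data
X_training = np.random.rand(10, 2)
y_training = np.random.randint(0, 2, size=10)
X_pool = np.random.rand(100, 2)

learner = ActiveLearner(RandomForestClassifier(n_estimators=2),
                        X_training=X_training, y_training=y_training)

# loss='binary' reproduces the old expected_error_reduction behaviour,
# loss='log' the removed expected_log_loss_reduction
query_idx, query_inst = expected_error_reduction(learner, X_pool,
                                                 loss='log', p_subsample=0.1)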
