Skip to content

Commit b126921

Browse files
committed
expected error reduction bugs fixed
1 parent f409358 commit b126921

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

modAL/expected_error.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,18 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
6767
cloned_estimator.fit(X_new, y_new)
6868
refitted_proba = cloned_estimator.predict_proba(X)
6969
if loss is 'binary':
70-
loss = _proba_uncertainty(refitted_proba)
70+
nloss = _proba_uncertainty(refitted_proba)
7171
elif loss is 'log':
72-
loss = _proba_entropy(refitted_proba)
73-
74-
expected_error[x_idx] += np.sum(loss)*X_proba[x_idx, y_idx]
72+
nloss = _proba_entropy(refitted_proba)
7573

74+
expected_error[x_idx] += np.sum(nloss)*X_proba[x_idx, y_idx]
7675

7776
else:
7877
expected_error[x_idx] = np.inf
7978

8079
if not random_tie_break:
81-
query_idx = multi_argmax(expected_error, n_instances)
80+
query_idx = multi_argmax(-expected_error, n_instances)
8281
else:
83-
query_idx = shuffled_argmax(expected_error, n_instances)
82+
query_idx = shuffled_argmax(-expected_error, n_instances)
8483

8584
return query_idx, X[query_idx]

0 commit comments

Comments
 (0)