Skip to content

Commit 3c01821

Browse files
committed
queried instance removed from the loss calculation
1 parent b126921 commit 3c01821

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

modAL/expected_error.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,14 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
5959
for x_idx, x in enumerate(X):
6060
# subsample the data if needed
6161
if np.random.rand() <= p_subsample:
62+
X_reduced = np.delete(X, x_idx, axis=0)
6263
# estimate the expected error
6364
for y_idx, y in enumerate(possible_labels):
6465
X_new = data_vstack((learner.X_training, np.expand_dims(x, axis=0)))
6566
y_new = data_vstack((learner.y_training, np.array(y).reshape(1,)))
6667

6768
cloned_estimator.fit(X_new, y_new)
68-
refitted_proba = cloned_estimator.predict_proba(X)
69+
refitted_proba = cloned_estimator.predict_proba(X_reduced)
6970
if loss is 'binary':
7071
nloss = _proba_uncertainty(refitted_proba)
7172
elif loss is 'log':

0 commit comments

Comments
 (0)