Skip to content

Commit 2927cfe

Browse files
committed
fix: expected_error_reduction query strategy runtime improved
1 parent 55424a5 commit 2927cfe

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

modAL/expected_error.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
5151
# TODO: implement a proper cold-start
5252
return 0, X[0]
5353

54+
cloned_estimator = clone(learner.estimator)
55+
5456
for x_idx, x in enumerate(X):
5557
# subsample the data if needed
5658
if np.random.rand() <= p_subsample:
@@ -59,8 +61,8 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
5961
X_new = data_vstack((learner.X_training, x.reshape(1, -1)))
6062
y_new = data_vstack((learner.y_training, np.array(y).reshape(1, )))
6163

62-
refitted_estimator = clone(learner.estimator).fit(X_new, y_new)
63-
refitted_proba = refitted_estimator.predict_proba(X)
64+
cloned_estimator.fit(X_new, y_new)
65+
refitted_proba = cloned_estimator.predict_proba(X)
6466
if loss is 'binary':
6567
loss = _proba_uncertainty(refitted_proba)
6668
elif loss is 'log':

0 commit comments

Comments
 (0)