Skip to content

Commit fc8e2e0

Browse files
committed
expected_error_reduction fixed for multidimensional data
1 parent a8eca52 commit fc8e2e0

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

modAL/expected_error.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
6161
if np.random.rand() <= p_subsample:
6262
# estimate the expected error
6363
for y_idx, y in enumerate(possible_labels):
64-
X_new = data_vstack((learner.X_training, x.reshape(1, -1)))
65-
y_new = data_vstack((learner.y_training, np.array(y).reshape(1, )))
64+
X_new = data_vstack((learner.X_training, np.expand_dims(x, axis=0)))
65+
y_new = data_vstack((learner.y_training, np.array(y).reshape(1,)))
6666

6767
cloned_estimator.fit(X_new, y_new)
6868
refitted_proba = cloned_estimator.predict_proba(X)
@@ -73,6 +73,7 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
7373

7474
expected_error[x_idx] += np.sum(loss)*X_proba[x_idx, y_idx]
7575

76+
7677
else:
7778
expected_error[x_idx] = np.inf
7879

tests/example_tests/multidimensional_data.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import numpy as np
2+
from sklearn.base import BaseEstimator
3+
24
from modAL.models import ActiveLearner
35
from modAL.uncertainty import margin_sampling, entropy_sampling
46
from modAL.batch import uncertainty_batch_sampling
57
from modAL.expected_error import expected_error_reduction
68

79

8-
class MockClassifier:
10+
class MockClassifier(BaseEstimator):
911
def __init__(self, n_classes=2):
1012
self.n_classes = n_classes
1113

@@ -21,11 +23,11 @@ def predict_proba(self, X):
2123

2224
if __name__ == '__main__':
2325
X_train = np.random.rand(10, 5, 5)
24-
y_train = np.random.rand(10, 1)
26+
y_train = np.random.randint(0, 2, size=10)
2527
X_pool = np.random.rand(10, 5, 5)
26-
y_pool = np.random.rand(10, 1)
28+
y_pool = np.random.randint(0, 2, size=10)
2729

28-
strategies = [margin_sampling, entropy_sampling, uncertainty_batch_sampling]
30+
strategies = [margin_sampling, entropy_sampling, uncertainty_batch_sampling, expected_error_reduction]
2931

3032
for query_strategy in strategies:
3133
print("testing %s..." % query_strategy.__name__)

0 commit comments

Comments
 (0)