Skip to content

Commit a8eca52

Browse files
committed
batch uncertainty sampling fixed for higher dimensional datasets
1 parent c4d62d6 commit a8eca52

File tree

3 files changed

+52
-6
lines changed

3 files changed

+52
-6
lines changed

modAL/batch.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,11 @@ def select_instance(
7979
Index of the best index from X chosen to be labelled; a single record from our unlabeled set that is considered
8080
the most optimal incremental record for including in our query set.
8181
"""
82+
X_pool_masked = X_pool[mask]
83+
8284
# Extract the number of labeled and unlabeled records.
83-
n_labeled_records, _ = X_training.shape
84-
n_unlabeled, _ = X_pool[mask].shape
85+
n_labeled_records, *rest = X_training.shape
86+
n_unlabeled, *rest = X_pool_masked.shape
8587

8688
# Determine our alpha parameter as |U| / (|U| + |D|). Note that because we
8789
# append to X_training and remove from X_pool within `ranked_batch`,
@@ -90,10 +92,15 @@ def select_instance(
9092

9193
# Compute pairwise distance (and then similarity) scores from every unlabeled record
9294
# to every record in X_training. The result is an array of shape (n_samples, ).
95+
9396
if n_jobs == 1 or n_jobs is None:
94-
_, distance_scores = pairwise_distances_argmin_min(X_pool[mask], X_training, metric=metric)
97+
_, distance_scores = pairwise_distances_argmin_min(X_pool_masked.reshape(n_unlabeled, -1),
98+
X_training.reshape(n_labeled_records, -1),
99+
metric=metric)
95100
else:
96-
distance_scores = pairwise_distances(X_pool[mask], X_training, metric=metric, n_jobs=n_jobs).min(axis=1)
101+
distance_scores = pairwise_distances(X_pool_masked.reshape(n_unlabeled, -1),
102+
X_training.reshape(n_labeled_records, -1),
103+
metric=metric, n_jobs=n_jobs).min(axis=1)
97104

98105
similarity_scores = 1 / (1 + distance_scores)
99106

@@ -103,11 +110,11 @@ def select_instance(
103110

104111
# Isolate and return our best instance for labeling as the one with the largest score.
105112
best_instance_index_in_unlabeled = np.argmax(scores)
106-
n_pool, _ = X_pool.shape
113+
n_pool, *rest = X_pool.shape
107114
unlabeled_indices = [i for i in range(n_pool) if mask[i]]
108115
best_instance_index = unlabeled_indices[best_instance_index_in_unlabeled]
109116
mask[best_instance_index] = 0
110-
return best_instance_index, X_pool[best_instance_index].reshape(1, -1), mask
117+
return best_instance_index, np.expand_dims(X_pool[best_instance_index], axis=0), mask
111118

112119

113120
def ranked_batch(classifier: Union[BaseLearner, BaseCommittee],

tests/core_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,7 @@ def test_strategies(self):
10711071
class TestExamples(unittest.TestCase):
10721072

10731073
def test_examples(self):
1074+
import example_tests.multidimensional_data
10741075
import example_tests.active_regression
10751076
import example_tests.bagging
10761077
import example_tests.ensemble
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import numpy as np
2+
from modAL.models import ActiveLearner
3+
from modAL.uncertainty import margin_sampling, entropy_sampling
4+
from modAL.batch import uncertainty_batch_sampling
5+
from modAL.expected_error import expected_error_reduction
6+
7+
8+
class MockClassifier:
    """Minimal stand-in estimator used to exercise query strategies.

    It learns nothing: ``fit`` is a no-op, ``predict`` returns random class
    indices and ``predict_proba`` returns a uniform distribution. Only
    ``len(X)`` is ever read from the inputs, so ``X`` may have any number
    of dimensions.
    """

    def __init__(self, n_classes=2):
        """Store the number of classes the mock pretends to distinguish."""
        self.n_classes = n_classes

    def fit(self, X, y):
        """Ignore the data and return ``self``."""
        return self

    def predict(self, X):
        """Return random class indices in ``[0, n_classes)``, shape ``(len(X), 1)``."""
        # np.random.randint takes the output shape via `size`; it has no
        # `shape` keyword, so the previous call raised TypeError.
        return np.random.randint(0, self.n_classes, size=(len(X), 1))

    def predict_proba(self, X):
        """Return uniform class probabilities, shape ``(len(X), n_classes)``."""
        return np.ones(shape=(len(X), self.n_classes)) / self.n_classes
20+
21+
22+
if __name__ == '__main__':
    # Smoke test: run each query strategy against 3-D feature arrays
    # (10 samples of shape 5x5) to check that multidimensional inputs
    # are accepted by query() and teach().
    X_train = np.random.rand(10, 5, 5)
    y_train = np.random.rand(10, 1)
    X_pool = np.random.rand(10, 5, 5)
    y_pool = np.random.rand(10, 1)

    for strategy in (margin_sampling, entropy_sampling, uncertainty_batch_sampling):
        print("testing %s..." % strategy.__name__)
        # A fresh learner per strategy, backed by the mock estimator.
        learner = ActiveLearner(
            estimator=MockClassifier(),
            query_strategy=strategy,
            X_training=X_train,
            y_training=y_train,
        )
        learner.query(X_pool)
        learner.teach(X_pool, y_pool)

0 commit comments

Comments
 (0)