Skip to content

Commit 0e66d27

Browse files
committed
Merge branch 'zhangyu94-fix-ranked-batch' into dev
2 parents 536636c + e973a32 commit 0e66d27

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

modAL/batch.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515

1616
def select_cold_start_instance(X: modALinput,
1717
metric: Union[str, Callable],
18-
n_jobs: Union[int, None]) -> modALinput:
18+
n_jobs: Union[int, None]) -> Tuple[int, modALinput]:
1919
"""
2020
Define what to do if our batch-mode sampling doesn't have any labeled data -- a cold start.
2121
@@ -35,15 +35,16 @@ def select_cold_start_instance(X: modALinput,
3535
n_jobs: This parameter is passed to :func:`~sklearn.metrics.pairwise.pairwise_distances`.
3636
3737
Returns:
38-
Best instance for cold-start.
38+
Index of the best cold-start instance from `X` chosen to be labelled; record of the best cold-start instance
39+
from `X` chosen to be labelled.
3940
"""
4041
# Compute all pairwise distances in our unlabeled data and obtain the row-wise average for each of our records in X.
4142
n_jobs = n_jobs if n_jobs else 1
4243
average_distances = np.mean(pairwise_distances(X, metric=metric, n_jobs=n_jobs), axis=0)
4344

4445
# Isolate and return our best instance for labeling as the record with the least average distance.
4546
best_coldstart_instance_index = np.argmin(average_distances)
46-
return X[best_coldstart_instance_index].reshape(1, -1)
47+
return best_coldstart_instance_index, X[best_coldstart_instance_index].reshape(1, -1)
4748

4849

4950
def select_instance(
@@ -133,14 +134,16 @@ def ranked_batch(classifier: Union[BaseLearner, BaseCommittee],
133134
The indices of the top n_instances ranked unlabelled samples.
134135
"""
135136
# Make a local copy of our classifier's training data.
137+
# Define our record container and record the best cold start instance in the case of cold start.
136138
if classifier.X_training is None:
137-
labeled = select_cold_start_instance(X=unlabeled, metric=metric, n_jobs=n_jobs)
139+
best_coldstart_instance_index, labeled = select_cold_start_instance(X=unlabeled, metric=metric, n_jobs=n_jobs)
140+
instance_index_ranking = [best_coldstart_instance_index]
138141
elif classifier.X_training.shape[0] > 0:
139142
labeled = classifier.X_training[:]
140-
141-
# Define our record container and the maximum number of records to sample.
142-
instance_index_ranking = []
143-
ceiling = np.minimum(unlabeled.shape[0], n_instances)
143+
instance_index_ranking = []
144+
145+
# The maximum number of records to sample.
146+
ceiling = np.minimum(unlabeled.shape[0], n_instances) - len(instance_index_ranking)
144147

145148
# mask for unlabeled initialized as transparent
146149
mask = np.ones(unlabeled.shape[0], np.bool)

0 commit comments

Comments
 (0)