@@ -79,9 +79,11 @@ def select_instance(
7979 Index of the best index from X chosen to be labelled; a single record from our unlabeled set that is considered
8080 the most optimal incremental record for including in our query set.
8181 """
82+ X_pool_masked = X_pool [mask ]
83+
8284 # Extract the number of labeled and unlabeled records.
83- n_labeled_records , _ = X_training .shape
84- n_unlabeled , _ = X_pool [ mask ] .shape
85+ n_labeled_records , * rest = X_training .shape
86+ n_unlabeled , * rest = X_pool_masked .shape
8587
8688 # Determine our alpha parameter as |U| / (|U| + |D|). Note that because we
8789 # append to X_training and remove from X_pool within `ranked_batch`,
@@ -90,10 +92,15 @@ def select_instance(
9092
9193 # Compute pairwise distance (and then similarity) scores from every unlabeled record
9294 # to every record in X_training. The result is an array of shape (n_samples, ).
95+
9396 if n_jobs == 1 or n_jobs is None :
94- _ , distance_scores = pairwise_distances_argmin_min (X_pool [mask ], X_training , metric = metric )
97+ _ , distance_scores = pairwise_distances_argmin_min (X_pool_masked .reshape (n_unlabeled , - 1 ),
98+ X_training .reshape (n_labeled_records , - 1 ),
99+ metric = metric )
95100 else :
96- distance_scores = pairwise_distances (X_pool [mask ], X_training , metric = metric , n_jobs = n_jobs ).min (axis = 1 )
101+ distance_scores = pairwise_distances (X_pool_masked .reshape (n_unlabeled , - 1 ),
102+ X_training .reshape (n_labeled_records , - 1 ),
103+ metric = metric , n_jobs = n_jobs ).min (axis = 1 )
97104
98105 similarity_scores = 1 / (1 + distance_scores )
99106
@@ -103,11 +110,11 @@ def select_instance(
103110
104111 # Isolate and return our best instance for labeling as the one with the largest score.
105112 best_instance_index_in_unlabeled = np .argmax (scores )
106- n_pool , _ = X_pool .shape
113+ n_pool , * rest = X_pool .shape
107114 unlabeled_indices = [i for i in range (n_pool ) if mask [i ]]
108115 best_instance_index = unlabeled_indices [best_instance_index_in_unlabeled ]
109116 mask [best_instance_index ] = 0
110- return best_instance_index , X_pool [best_instance_index ]. reshape ( 1 , - 1 ), mask
117+ return best_instance_index , np . expand_dims ( X_pool [best_instance_index ], axis = 0 ), mask
111118
112119
113120def ranked_batch (classifier : Union [BaseLearner , BaseCommittee ],
0 commit comments