Skip to content

Commit 73d5318

Browse files
authored
Fix the feature extraction bug (#1942)
1 parent f089a18 commit 73d5318

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

applications/neural_search/recall/milvus/feature_extract.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,13 @@ def predict(self, data, tokenizer):
156156
logits = self.output_handle.copy_to_cpu()
157157
all_embeddings.append(logits)
158158
examples = []
159-
159+
if (len(examples) > 0):
160+
input_ids, segment_ids = batchify_fn(examples)
161+
self.input_handles[0].copy_from_cpu(input_ids)
162+
self.input_handles[1].copy_from_cpu(segment_ids)
163+
self.predictor.run()
164+
logits = self.output_handle.copy_to_cpu()
165+
all_embeddings.append(logits)
160166
all_embeddings = np.concatenate(all_embeddings, axis=0)
161167
np.save('corpus_embedding', all_embeddings)
162168

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
CUDA_VISIBLE_DEVICES=2 python feature_extract.py \
22
--model_dir=./output \
3-
--corpus_file "data/milvus_data.csv"
3+
--corpus_file "milvus/milvus_data.csv"

0 commit comments

Comments
 (0)