Skip to content

Commit 5b18f36

Browse files
authored
[feat] Use encode_document and encode_query in mine_hard_negatives (#3502)
* Use encode_document and encode_query in mine_hard_negatives
* Patch test mock tokenize function
1 parent 6b08a5a commit 5b18f36

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

sentence_transformers/util/hard_negatives.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def mine_hard_negatives(
369369
target_devices=None if isinstance(use_multi_process, bool) else use_multi_process
370370
)
371371
if corpus_embeddings is None:
372-
corpus_embeddings = model.encode(
372+
corpus_embeddings = model.encode_document(
373373
corpus,
374374
pool=pool,
375375
batch_size=batch_size,
@@ -380,7 +380,7 @@ def mine_hard_negatives(
380380
prompt=corpus_prompt,
381381
)
382382
if query_embeddings is None:
383-
query_embeddings = model.encode(
383+
query_embeddings = model.encode_query(
384384
queries,
385385
pool=pool,
386386
batch_size=batch_size,
@@ -393,7 +393,7 @@ def mine_hard_negatives(
393393
model.stop_multi_process_pool(pool)
394394
else:
395395
if corpus_embeddings is None:
396-
corpus_embeddings = model.encode(
396+
corpus_embeddings = model.encode_document(
397397
corpus,
398398
batch_size=batch_size,
399399
normalize_embeddings=True,
@@ -403,7 +403,7 @@ def mine_hard_negatives(
403403
prompt=corpus_prompt,
404404
)
405405
if query_embeddings is None:
406-
query_embeddings = model.encode(
406+
query_embeddings = model.encode_query(
407407
queries,
408408
batch_size=batch_size,
409409
normalize_embeddings=True,

tests/util/test_hard_negatives.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -810,9 +810,9 @@ def test_mine_hard_negatives_with_prompt(paraphrase_distilroberta_base_v1_model:
810810
original_tokenize = model.tokenize
811811
tokenize_calls = []
812812

813-
def mock_tokenize(texts) -> dict[str, Tensor]:
813+
def mock_tokenize(texts, **kwargs) -> dict[str, Tensor]:
814814
tokenize_calls.append(texts)
815-
return original_tokenize(texts)
815+
return original_tokenize(texts, **kwargs)
816816

817817
# 2. Run without prompt - check that no prompt is added
818818
model.tokenize = mock_tokenize

0 commit comments

Comments (0)