Skip to content

Commit 2163bea

Browse files
authored
Merge pull request #1340 from hanhainebula/master
Fix bugs
2 parents db83452 + 63324c3 commit 2163bea

File tree

3 files changed: 8 additions and 9 deletions.
Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +0,0 @@
-from .air_bench import AIRBenchEvalModelArgs, AIRBenchEvalArgs, AIRBenchEvalRunner
-from .beir import *
-# from miracle import *
-# from mkqa import *
-# from mldr import *
-# from msmarco import *
-from mteb import *

FlagEmbedding/evaluation/beir/data_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, s
         Returns:
             datasets.DatasetDict: A dict of relevance of query and document.
         """
-        checked_split = self.check_splits(split)
+        checked_split = self.check_splits(split, dataset_name=dataset_name)
         if len(checked_split) == 0:
             raise ValueError(f"Split {split} not found in the dataset.")
         split = checked_split[0]
@@ -450,7 +450,7 @@ def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None,
         Returns:
             datasets.DatasetDict: A dict of queries with id as key, query text as value.
         """
-        checked_split = self.check_splits(split)
+        checked_split = self.check_splits(split, dataset_name=dataset_name)
         if len(checked_split) == 0:
             raise ValueError(f"Split {split} not found in the dataset.")
         split = checked_split[0]

scripts/hn_mine.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ def find_knn_neg(
     p_vecs = model.encode(corpus)
     print(f'inferencing embedding for queries (number={len(queries)})--------------')
     q_vecs = model.encode_queries(queries)
+
+    # check if the embeddings are in dictionary format: M3Embedder
+    if isinstance(p_vecs, dict):
+        p_vecs = p_vecs["dense_vecs"]
+    if isinstance(q_vecs, dict):
+        q_vecs = q_vecs["dense_vecs"]
 
     print('create index and search------------------')
     index = create_index(p_vecs, use_gpu=use_gpu)

0 commit comments

Comments
 (0)