Skip to content

Commit 5ba2c4f

Browse files
committed
updated latest dense retriever models including ST, huggingface and LLM2Vec
1 parent 7dd3410 commit 5ba2c4f

19 files changed

+639
-115
lines changed

README.md

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Tested with python versions 3.9+
8383

8484
- Preprocess your own IR dataset or use one of the already-preprocessed 17 benchmark datasets
8585
- Wide settings included, covers diverse benchmarks useful for both academia and industry
86-
- Includes well-known retrieval architectures (lexical, dense, sparse and reranking-based)
86+
- Evaluates well-known retrieval architectures (lexical, dense, sparse and reranking-based)
8787
- Add and evaluate your own model in an easy framework using different state-of-the-art evaluation metrics
8888

8989
## :beers: Quick Example
@@ -132,14 +132,15 @@ results = retriever.retrieve(corpus, queries)
132132

133133
#### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K where k = [1,3,5,10,100,1000]
134134
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
135+
mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
135136

136137
### If you want to save your results and runfile (useful for reranking)
137138
results_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "results")
138139
os.makedirs(results_dir, exist_ok=True)
139140

140141
#### Save the evaluation runfile & results
141142
util.save_runfile(os.path.join(results_dir, f"{dataset}.run.trec"), results)
142-
util.save_results(os.path.join(results_dir, f"{dataset}.json"), ndcg, _map, recall, precision)
143+
util.save_results(os.path.join(results_dir, f"{dataset}.json"), ndcg, _map, recall, precision, mrr)
143144
```
144145

145146
## :beers: Available Datasets
@@ -227,13 +228,22 @@ If you find this repository helpful, feel free to cite our publication [BEIR: A
227228

228229
If you use any baseline score from the BEIR leaderboard, feel free to cite our publication [Resources for Brewing BEIR: Reproducible Reference Models and an Official Leaderboard](https://arxiv.org/abs/2306.07471)
229230
```
230-
@misc{kamalloo2023resources,
231-
title={Resources for Brewing BEIR: Reproducible Reference Models and an Official Leaderboard},
232-
author={Ehsan Kamalloo and Nandan Thakur and Carlos Lassance and Xueguang Ma and Jheng-Hong Yang and Jimmy Lin},
233-
year={2023},
234-
eprint={2306.07471},
235-
archivePrefix={arXiv},
236-
primaryClass={cs.IR}
231+
@inproceedings{kamalloo:2024,
232+
author = {Kamalloo, Ehsan and Thakur, Nandan and Lassance, Carlos and Ma, Xueguang and Yang, Jheng-Hong and Lin, Jimmy},
233+
title = {Resources for Brewing BEIR: Reproducible Reference Models and Statistical Analyses},
234+
year = {2024},
235+
isbn = {9798400704314},
236+
publisher = {Association for Computing Machinery},
237+
address = {New York, NY, USA},
238+
url = {https://doi.org/10.1145/3626772.3657862},
239+
doi = {10.1145/3626772.3657862},
240+
abstract = {BEIR is a benchmark dataset originally designed for zero-shot evaluation of retrieval models across 18 different domain/task combinations. In recent years, we have witnessed the growing popularity of models based on representation learning, which naturally begs the question: How effective are these models when presented with queries and documents that differ from the training data? While BEIR was designed to answer this question, our work addresses two shortcomings that prevent the benchmark from achieving its full potential: First, the sophistication of modern neural methods and the complexity of current software infrastructure create barriers to entry for newcomers. To this end, we provide reproducible reference implementations that cover learned dense and sparse models. Second, comparisons on BEIR are performed by reducing scores from heterogeneous datasets into a single average that is difficult to interpret. To remedy this, we present meta-analyses focusing on effect sizes across datasets that are able to accurately quantify model differences. By addressing both shortcomings, our work facilitates future explorations in a range of interesting research questions.},
241+
booktitle = {Proceedings of the 47th International ACM SIGIR Conference on Research and Development in Information Retrieval},
242+
pages = {1431–1440},
243+
numpages = {10},
244+
keywords = {domain generalization, evaluation, reproducibility},
245+
location = {Washington DC, USA},
246+
series = {SIGIR '24}
237247
}
238248
```
239249

beir/retrieval/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from .bpr import BinarySentenceBERT
44
from .huggingface import HuggingFace
5+
from .llm2vec import LLM2Vec
6+
from .nvembed import NVEmbed
57
from .sentence_bert import SentenceBERT
68
from .sparta import SPARTA
79
from .splade import SPLADE
@@ -11,6 +13,8 @@
1113
__all__ = [
1214
"BinarySentenceBERT",
1315
"HuggingFace",
16+
"LLM2Vec",
17+
"NVEmbed",
1418
"SentenceBERT",
1519
"SPARTA",
1620
"SPLADE",

beir/retrieval/models/huggingface.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,21 @@
2121
POOL_FUNC = {"cls": cls_pooling, "mean": mean_pooling, "eos": eos_pooling}
2222

2323

24-
def get_peft_model(peft_model_name: str) -> PeftModel:
24+
def get_peft_model(peft_model_name: str, **kwargs) -> tuple[PeftModel, str]:
2525
config = PeftConfig.from_pretrained(peft_model_name)
26-
base_model = AutoModel.from_pretrained(config.base_model_name_or_path)
26+
logger.info(f"Loading Auto Model from {config.base_model_name_or_path} for PEFT model")
27+
base_model = AutoModel.from_pretrained(
28+
config.base_model_name_or_path,
29+
device_map="auto",
30+
attn_implementation=kwargs.get("attn_implementation", "eager"),
31+
torch_dtype=kwargs.get("torch_dtype", "auto"),
32+
trust_remote_code=True,
33+
cache_dir=kwargs.get("cache_dir", None),
34+
)
35+
logger.info(f"Loading PEFT model from {peft_model_name}")
2736
model = PeftModel.from_pretrained(base_model, peft_model_name)
2837
model = model.merge_and_unload()
29-
return model
38+
return model, config.base_model_name_or_path
3039

3140

3241
class HuggingFace:
@@ -43,18 +52,23 @@ def __init__(
4352
**kwargs,
4453
):
4554
self.sep = sep
46-
self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
47-
if self.tokenizer.pad_token_id is None:
48-
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
49-
self.tokenizer.padding_side = "right"
50-
5155
if peft_model_path:
52-
self.model = get_peft_model(peft_model_path)
56+
self.model, base_model_path = get_peft_model(peft_model_path, **kwargs)
57+
self.tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=True)
5358
else:
5459
self.model = AutoModel.from_pretrained(
55-
model_path, device_map="auto", torch_dtype=kwargs.get("torch_dtype", "auto"), trust_remote_code=True
60+
model_path,
61+
device_map="auto",
62+
torch_dtype=kwargs.get("torch_dtype", "auto"),
63+
trust_remote_code=True,
64+
attn_implementation=kwargs.get("attn_implementation", "default"),
65+
cache_dir=kwargs.get("cache_dir", None),
5666
)
67+
self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
5768
self.model.eval()
69+
if self.tokenizer.pad_token_id is None:
70+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
71+
self.tokenizer.padding_side = "right"
5872
self.max_length = max_length if max_length else self.tokenizer.model_max_length
5973
self.normalize = normalize # Normalize the embeddings
6074
self.append_eos_token = append_eos_token # Add eos token to the input

beir/retrieval/models/llm2vec.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from __future__ import annotations
2+
3+
import importlib.util
4+
import logging
5+
6+
if importlib.util.find_spec("llm2vec") is not None:
7+
from llm2vec import LLM2Vec as LLM2VecOriginal
8+
9+
import numpy as np
10+
import torch
11+
import torch.nn.functional as F
12+
from torch import Tensor
13+
from tqdm.autonotebook import trange
14+
15+
from .util import extract_corpus_sentences
16+
17+
logger = logging.getLogger(__name__)
18+
19+
POOLING_MODES = {
20+
"mean": "mean",
21+
"weighted_mean": "weighted_mean",
22+
"eos": "eos_token",
23+
"bos_token": "bos_token",
24+
"last_token": "last_token",
25+
}
26+
27+
28+
class LLM2Vec:
    """Dense retriever wrapper around the ``llm2vec`` package.

    Loads an LLM2Vec model (optionally with a PEFT adapter) and exposes
    ``encode_queries``/``encode_corpus`` that return a stacked tensor of
    embeddings, L2-normalized when ``normalize`` is True.
    """

    def __init__(
        self,
        model_path: str | tuple = None,
        max_length: int = None,
        sep: str = " ",
        pooling: str = "mean",
        normalize: bool = True,
        prompts: dict[str, str] = None,
        peft_model_path: str = None,
        **kwargs,
    ):
        """Load the LLM2Vec model.

        Args:
            model_path: Base model name or path.
            max_length: Maximum sequence length forwarded to LLM2Vec.
            sep: Separator used when joining title and text of a document.
            pooling: One of the keys of ``POOLING_MODES``.
            normalize: Whether to L2-normalize the returned embeddings.
            prompts: Optional instruction prefixes, keys ``"query"``/``"passage"``.
            peft_model_path: Optional PEFT adapter name or path.
            **kwargs: Forwarded to ``LLM2VecOriginal.from_pretrained``.

        Raises:
            ValueError: If ``pooling`` is not a supported mode.
        """
        self.sep = sep
        self.normalize = normalize
        if pooling not in POOLING_MODES:
            raise ValueError(f"Pooling mode {pooling} not supported. Choose from {list(POOLING_MODES.keys())}")

        self.model = LLM2VecOriginal.from_pretrained(
            base_model_name_or_path=model_path,
            peft_model_name_or_path=peft_model_path,
            pooling_mode=POOLING_MODES[pooling],
            max_length=max_length,
            **kwargs,
        )

        # Bugfix: default to empty prefixes so encode_queries/encode_corpus do
        # not raise AttributeError when ``prompts`` is not provided.
        self.query_prefix = ""
        self.doc_prefix = ""
        if prompts:
            self.query_prefix = prompts.get("query", "")
            self.doc_prefix = prompts.get("passage", "")

    def _append_eos_token(self, texts, pad_to_multiple_of: int = 16):
        """Tokenizes the input texts and pads the tokenized input to the max_length with the eos token.

        NOTE(review): this helper reads ``self.tokenizer``, ``self.max_length``
        and ``self.append_eos_token``, none of which are set by ``__init__`` of
        this class — calling it raises AttributeError. It appears to have been
        copied from the HuggingFace wrapper; confirm whether it is needed here.
        """
        collated_texts = self.tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=self.max_length - 1 if self.append_eos_token else self.max_length,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=True,
        )
        collated_texts["input_ids"] = [x + [self.tokenizer.eos_token_id] for x in collated_texts["input_ids"]]
        collated_texts = self.tokenizer.pad(
            collated_texts,
            padding=True,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=True,
            return_tensors="pt",
        )
        return collated_texts

    def encode_queries(self, queries: list[str], batch_size: int = 16, **kwargs) -> list[Tensor] | np.ndarray | Tensor:
        """Embed queries; returns a ``(len(queries), dim)`` tensor."""
        query_embeddings = []

        with torch.no_grad():
            for start_idx in trange(0, len(queries), batch_size):
                # LLM2Vec expects [instruction, text] pairs for instructed encoding.
                sub_queries = [[self.query_prefix, query] for query in queries[start_idx : start_idx + batch_size]]
                query_embeddings += self.model.encode(sub_queries, batch_size=batch_size, show_progress_bar=False)

        query_embeddings = torch.stack(query_embeddings)

        if self.normalize:
            query_embeddings = F.normalize(query_embeddings, p=2, dim=1)

        return query_embeddings

    def encode_corpus(
        self, corpus: list[dict[str, str]] | dict[str, list] | list[str], batch_size: int = 8, **kwargs
    ) -> list[Tensor] | np.ndarray | Tensor:
        """Embed corpus documents; returns a ``(len(corpus), dim)`` tensor.

        The passage instruction is prepended only when one was configured.
        """
        corpus_embeddings = []
        sentences = extract_corpus_sentences(corpus=corpus, sep=self.sep)

        with torch.no_grad():
            for start_idx in trange(0, len(sentences), batch_size):
                if self.doc_prefix:
                    sub_sentences = [
                        [self.doc_prefix, sentence] for sentence in sentences[start_idx : start_idx + batch_size]
                    ]
                else:
                    sub_sentences = sentences[start_idx : start_idx + batch_size]
                corpus_embeddings += self.model.encode(sub_sentences, batch_size=batch_size, show_progress_bar=False)

        corpus_embeddings = torch.stack(corpus_embeddings)

        if self.normalize:
            corpus_embeddings = F.normalize(corpus_embeddings, p=2, dim=1)

        return corpus_embeddings

beir/retrieval/models/nvembed.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
5+
import numpy as np
6+
import torch
7+
import torch.nn.functional as F
8+
from torch import Tensor
9+
from tqdm.autonotebook import trange
10+
from transformers import AutoModel
11+
12+
from .pooling import cls_pooling, eos_pooling, mean_pooling
13+
from .util import extract_corpus_sentences
14+
15+
logger = logging.getLogger(__name__)
16+
17+
POOL_FUNC = {"cls": cls_pooling, "mean": mean_pooling, "eos": eos_pooling}
18+
19+
20+
class NVEmbed:
21+
def __init__(
22+
self,
23+
model_path: str | tuple = None,
24+
max_length: int = None,
25+
sep: str = " ",
26+
pooling: str = "mean",
27+
normalize: bool = False,
28+
prompts: dict[str, str] = None,
29+
**kwargs,
30+
):
31+
self.sep = sep
32+
self.model = AutoModel.from_pretrained(
33+
model_path, device_map="auto", torch_dtype=kwargs.get("torch_dtype", "auto"), trust_remote_code=True
34+
)
35+
# self.model.eval()
36+
self.max_length = max_length if max_length else self.tokenizer.model_max_length
37+
self.normalize = normalize # Normalize the embeddings
38+
39+
if pooling not in ["cls", "mean", "eos"]:
40+
raise ValueError("Supported Pooling techniques should be either 'cls', 'mean' or 'eos'")
41+
self.pooling_func = POOL_FUNC[pooling]
42+
43+
if prompts:
44+
self.query_prefix = prompts.get("query", "")
45+
self.doc_prefix = prompts.get("passage", "")
46+
47+
def encode_queries(self, queries: list[str], batch_size: int = 16, **kwargs) -> list[Tensor] | np.ndarray | Tensor:
48+
query_embeddings = []
49+
50+
with torch.no_grad():
51+
for start_idx in trange(0, len(queries), batch_size):
52+
sub_queries = [self.query_prefix + query for query in queries[start_idx : start_idx + batch_size]]
53+
query_embeddings += self.model.encode(
54+
sub_queries, instruction=self.query_prefix, max_length=self.max_length
55+
)
56+
57+
query_embeddings = torch.stack(query_embeddings)
58+
59+
if self.normalize:
60+
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
61+
62+
return query_embeddings
63+
64+
def encode_corpus(
65+
self, corpus: list[dict[str, str]] | dict[str, list] | list[str], batch_size: int = 8, **kwargs
66+
) -> list[Tensor] | np.ndarray | Tensor:
67+
corpus_embeddings = []
68+
sentences = extract_corpus_sentences(corpus=corpus, sep=self.sep)
69+
70+
with torch.no_grad():
71+
for start_idx in trange(0, len(sentences), batch_size):
72+
sub_sentences = [
73+
self.doc_prefix + sentence for sentence in sentences[start_idx : start_idx + batch_size]
74+
]
75+
corpus_embeddings += self.model.encode(
76+
sub_sentences, instruction=self.doc_prefix, max_length=self.max_length
77+
)
78+
79+
corpus_embeddings = torch.stack(corpus_embeddings)
80+
81+
if self.normalize:
82+
corpus_embeddings = F.normalize(corpus_embeddings, p=2, dim=1)
83+
84+
return corpus_embeddings

beir/retrieval/models/sentence_bert.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,42 @@ class SentenceBERT:
1717
def __init__(
1818
self,
1919
model_path: str | tuple = None,
20+
max_length: int = None,
2021
sep: str = " ",
2122
prompts: dict[str, str] = None,
23+
prompt_names: dict[str, str] = None,
2224
**kwargs,
2325
):
2426
self.sep = sep
27+
self.max_length = max_length
2528

2629
if isinstance(model_path, str):
27-
self.q_model = SentenceTransformer(model_path, kwargs)
30+
self.q_model = SentenceTransformer(model_path, **kwargs)
2831
self.doc_model = self.q_model
2932

3033
elif isinstance(model_path, tuple):
31-
self.q_model = SentenceTransformer(model_path[0], kwargs)
32-
self.doc_model = SentenceTransformer(model_path[1], kwargs)
34+
self.q_model = SentenceTransformer(model_path[0], **kwargs)
35+
self.doc_model = SentenceTransformer(model_path[1], **kwargs)
3336

34-
self.query_prefix = ""
35-
self.doc_prefix = ""
37+
if self.max_length:
38+
self.q_model.max_seq_length = self.max_length
39+
self.doc_model.max_seq_length = self.max_length
40+
41+
self.query_prefix, self.query_prompt_name = None, None
42+
self.doc_prefix, self.doc_prompt_name = None, None
3643

3744
# Checks if prompts are not set in Sentence Transformers but required during inference
3845
if prompts and (len(self.q_model.prompts) or len(self.doc_model.prompts) == 0):
3946
self.query_prefix = prompts["query"]
4047
self.doc_prefix = prompts["passage"]
4148

49+
if prompt_names:
50+
self.query_prompt_name = prompt_names["query"]
51+
self.doc_prompt_name = prompt_names["passage"]
52+
53+
logger.info(f"Query prompt: {self.query_prefix}, Passage prompt: {self.doc_prefix}")
54+
logger.info(f"Query prompt name: {self.query_prompt_name}, Passage prompt name: {self.doc_prompt_name}")
55+
4256
def get_similarity(self):
4357
return self.q_model.similarity
4458

@@ -74,7 +88,9 @@ def stop_multi_process_pool(self, pool: dict[str, object]):
7488

7589
def encode_queries(self, queries: list[str], batch_size: int = 16, **kwargs) -> list[Tensor] | np.ndarray | Tensor:
7690
return self.q_model.encode(
77-
[self.query_prefix + query for query in queries],
91+
queries,
92+
prompt=self.query_prefix,
93+
prompt_name=self.query_prompt_name,
7894
batch_size=batch_size,
7995
**kwargs,
8096
)
@@ -87,7 +103,9 @@ def encode_corpus(
87103
) -> list[Tensor] | np.ndarray | Tensor:
88104
sentences = extract_corpus_sentences(corpus=corpus, sep=self.sep)
89105
return self.doc_model.encode(
90-
[self.doc_prefix + sentence for sentence in sentences],
106+
sentences,
107+
prompt=self.doc_prefix,
108+
prompt_name=self.doc_prompt_name,
91109
batch_size=batch_size,
92110
**kwargs,
93111
)

0 commit comments

Comments
 (0)