Commit d2a4dbb

feat: add beir benchmark for embedding models
1 parent b1ecde0 commit d2a4dbb

File tree

3 files changed: +510 -0 lines changed
Lines changed: 344 additions & 0 deletions
@@ -0,0 +1,344 @@
import argparse
import os
import uuid
import pathlib
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack.apis.tools import RAGQueryConfig
from llama_stack_client.types import Document

import numpy as np
import pytrec_eval

import itertools
import logging

DEFAULT_DATASET_NAMES = ["scifact"]
DEFAULT_CUSTOM_DATASETS_URLS = []
DEFAULT_EMBEDDING_MODELS = ["granite-embedding-30m", "granite-embedding-125m"]
DEFAULT_BATCH_SIZE = 150

def parse_args():
    parser = argparse.ArgumentParser(
        description="Benchmark embedding models with BEIR datasets"
    )

    parser.add_argument(
        "--dataset-names",
        nargs="+",
        type=str,
        default=DEFAULT_DATASET_NAMES,
        help=f"List of BEIR datasets to evaluate (default: {DEFAULT_DATASET_NAMES})",
    )

    parser.add_argument(
        "--custom-datasets-urls",
        nargs="+",
        type=str,
        default=DEFAULT_CUSTOM_DATASETS_URLS,
        help=f"Custom URLs for datasets (default: {DEFAULT_CUSTOM_DATASETS_URLS})",
    )

    parser.add_argument(
        "--embedding-models",
        nargs="+",
        type=str,
        default=DEFAULT_EMBEDDING_MODELS,
        help=f"List of embedding models to evaluate (default: {DEFAULT_EMBEDDING_MODELS})",
    )

    parser.add_argument(
        "--batch-size",
        type=int,
        default=DEFAULT_BATCH_SIZE,
        help=f"Batch size for injecting documents (default: {DEFAULT_BATCH_SIZE})",
    )

    return parser.parse_args()

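# Example invocation (a sketch: the script filename below is illustrative, and the
# model IDs must correspond to embedding models registered in your Llama Stack
# distribution's run.yaml):
#
#   python benchmark_beir.py \
#       --dataset-names scifact \
#       --embedding-models granite-embedding-30m granite-embedding-125m \
#       --batch-size 100
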
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)


# Load BEIR dataset
def load_beir_dataset(dataset_name: str, custom_datasets_pairs: dict):
    if custom_datasets_pairs == {}:
        url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
    else:
        url = custom_datasets_pairs[dataset_name]

    out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
    data_path = os.path.join(out_dir, dataset_name)

    if not os.path.isdir(data_path):
        data_path = util.download_and_unzip(url, out_dir)

    corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
    return corpus, queries, qrels

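# For reference, GenericDataLoader returns roughly the following shapes, which the
# rest of this script relies on (the literal values are illustrative only):
#
#   corpus:  {"doc1": {"title": "...", "text": "..."}, ...}
#   queries: {"q1": "natural-language question", ...}
#   qrels:   {"q1": {"doc1": 1, "doc7": 0}, ...}   # query -> {doc_id: relevance}
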
"""
89+
This function is used to inject documents into the LlamaStack vector database.
90+
The documents are processed in batches to avoid memory issues.
91+
"""
92+
93+
94+
def inject_documents(
95+
llama_stack_client: LlamaStackAsLibraryClient, corpus: dict, batch_size: int, vector_db_provider_id: str, embedding_model: str
96+
) -> str:
97+
vector_db_id = f"beir-rag-eval-{embedding_model}-{uuid.uuid4().hex}"
98+
99+
llama_stack_client.vector_dbs.register(
100+
vector_db_id=vector_db_id,
101+
embedding_model=embedding_model,
102+
provider_id=vector_db_provider_id,
103+
)
104+
105+
# Convert corpus into Documents and process in batches
106+
corpus_items = list(corpus.items())
107+
total_docs = len(corpus_items)
108+
109+
print(f"Processing {total_docs} documents in batches of {batch_size}")
110+
111+
for i in range(0, total_docs, batch_size):
112+
batch_items = corpus_items[i : i + batch_size]
113+
documents_batch = [
114+
Document(
115+
document_id=doc_id,
116+
content=data["title"] + " " + data["text"],
117+
mime_type="text/plain",
118+
metadata={},
119+
)
120+
for doc_id, data in batch_items
121+
]
122+
123+
print(
124+
f"Inserting batch {i // batch_size + 1}/{(total_docs + batch_size - 1) // batch_size} ({len(documents_batch)} docs)"
125+
)
126+
127+
llama_stack_client.tool_runtime.rag_tool.insert(
128+
documents=documents_batch,
129+
vector_db_id=vector_db_id,
130+
chunk_size_in_tokens=512,
131+
timeout=3600,
132+
)
133+
134+
print(f"Successfully inserted all {total_docs} documents")
135+
return vector_db_id
136+
137+
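# Note on granularity (an observation about the code above, not additional behavior):
# each BEIR corpus entry becomes a single Document whose content is "title text", and
# rag_tool.insert splits it into ~512-token chunks. At query time the retriever below
# reads document_ids from the result metadata, so hits are reported at the BEIR
# doc_id level even though the index stores chunks.
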
# LlamaStack RAG Retriever
class LlamaStackRAGRetriever:
    def __init__(
        self,
        llama_stack_client: LlamaStackAsLibraryClient,
        vector_db_id: str,
        query_config: RAGQueryConfig,
        top_k: int = 10,
    ):
        self.llama_stack_client = llama_stack_client
        self.vector_db_id = vector_db_id
        self.query_config = query_config
        self.top_k = top_k

    def retrieve(self, queries, top_k=None):
        results = {}
        top_k = top_k or self.top_k

        for qid, query in queries.items():
            rag_results = self.llama_stack_client.tool_runtime.rag_tool.query(
                vector_db_ids=[self.vector_db_id],
                content=query,
                query_config={**self.query_config, "max_chunks": top_k},
            )

            doc_ids = rag_results.metadata.get("document_ids", [])
            scores = {doc_id: 1.0 - (i * 0.01) for i, doc_id in enumerate(doc_ids)}

            results[qid] = scores

        return results

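# The retriever above keeps only the ranked document_ids from the result metadata and
# synthesizes descending pseudo-scores (1.0, 0.99, 0.98, ...). Rank-based metrics such
# as nDCG and MAP depend only on the ordering, so this is sufficient for pytrec_eval,
# whose "run" input then looks roughly like this (doc ids illustrative):
#
#   {"q1": {"doc12": 1.0, "doc3": 0.99, "doc44": 0.98}, ...}
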
# Adapted from https://github.com/opendatahub-io/llama-stack-demos/blob/main/demos/rag_eval/Agentic_RAG_with_reference_eval.ipynb
def permutation_test_for_paired_samples(scores_a: list, scores_b: list, iterations: int = 10_000):
    """
    Performs a permutation test of a given statistic on provided data.
    """

    from scipy.stats import permutation_test

    def _statistic(x, y, axis):
        return np.mean(x, axis=axis) - np.mean(y, axis=axis)

    result = permutation_test(
        data=(scores_a, scores_b),
        statistic=_statistic,
        n_resamples=iterations,
        alternative="two-sided",
        permutation_type="samples",
    )
    return float(result.pvalue)

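# Illustrative usage with made-up per-query nDCG@10 values (not from a real run):
# because permutation_type="samples", the two lists are treated as paired
# observations on the same queries when building the null distribution.
#
#   p = permutation_test_for_paired_samples(
#       [0.61, 0.72, 0.55, 0.68],   # model A, one score per query
#       [0.58, 0.70, 0.49, 0.66],   # model B, same queries in the same order
#   )
#   # p < 0.05 would suggest the paired difference in means is statistically significant
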
# Adapted from https://github.com/opendatahub-io/llama-stack-demos/blob/main/demos/rag_eval/Agentic_RAG_with_reference_eval.ipynb
def print_stats_significance(scores_a: list, scores_b: list, overview_label: str, label_a: str, label_b: str):
    mean_score_a = np.mean(scores_a)
    mean_score_b = np.mean(scores_b)

    p_value = permutation_test_for_paired_samples(scores_a, scores_b)
    print(overview_label)
    print(f" {label_a:<50}: {mean_score_a:>10.4f}")
    print(f" {label_b:<50}: {mean_score_b:>10.4f}")
    print(f" {'p_value':<50}: {p_value:>10.4f}")

    if p_value < 0.05:
        print(" p_value<0.05 so this result is statistically significant")
        # Note that the logic below is incorrect if the mean scores are equal, but that can't be true if p<1.
        higher_model_id = label_a if mean_score_a >= mean_score_b else label_b
        print(
            f" You can conclude that {higher_model_id} retrieval is better on data of this sort"
        )
    else:
        import math

        print(" p_value>=0.05 so this result is NOT statistically significant.")
        print(
            " You can conclude that there is not enough data to tell which is better."
        )
        num_samples = len(scores_a)
        margin_of_error = 1 / math.sqrt(num_samples)
        print(
            f" Note that this data includes {num_samples} questions which typically produces a margin of error of around +/-{margin_of_error:.1%}."
        )
        print(" So the two are probably roughly within that margin of error or so.")

def get_metrics(all_scores: dict):
    for scores_for_dataset in all_scores.values():
        for scores_for_condition in scores_for_dataset.values():
            for scores_for_question in scores_for_condition.values():
                metrics = list(scores_for_question.keys())
                metrics.sort()
                return metrics
    return []


def print_scores(all_scores: dict):
    metrics = get_metrics(all_scores)
    for dataset_name, scores_for_dataset in all_scores.items():
        condition_labels = list(scores_for_dataset.keys())
        condition_labels.sort()
        for metric in metrics:
            overview_label = f"{dataset_name} {metric}"
            for label_a, label_b in itertools.combinations(condition_labels, 2):
                scores_for_label_a = scores_for_dataset[label_a]
                scores_for_label_b = scores_for_dataset[label_b]
                scores_a = [
                    score_group[metric] for score_group in scores_for_label_a.values()
                ]
                scores_b = [
                    score_group[metric] for score_group in scores_for_label_b.values()
                ]
                print_stats_significance(
                    scores_a, scores_b, overview_label, label_a, label_b
                )
                print("\n")

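# print_scores expects all_scores nested as produced below in evaluate_retrieval
# (the numbers are illustrative; metric keys follow pytrec_eval's naming):
#
#   all_scores = {
#       "scifact": {
#           "granite-embedding-30m":  {"q1": {"ndcg_cut_10": 0.71, "map_cut_10": 0.65, ...}, ...},
#           "granite-embedding-125m": {"q1": {"ndcg_cut_10": 0.74, "map_cut_10": 0.69, ...}, ...},
#       },
#   }
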
def evaluate_retrieval(
    llama_stack_client: LlamaStackAsLibraryClient,
    datasets: list[str],
    custom_datasets_urls: list[str],
    batch_size: int,
    vector_db_provider_id: str,
    embedding_models: list[str],
):
    results_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "results")
    all_scores = {}

    custom_datasets_pairs = {}
    if custom_datasets_urls:
        custom_datasets_pairs = {
            dataset_name: custom_datasets_urls[i]
            for i, dataset_name in enumerate(datasets)
        }

    for dataset_name in datasets:
        all_scores[dataset_name] = {}
        corpus, queries, qrels = load_beir_dataset(dataset_name, custom_datasets_pairs)
        for embedding_model in embedding_models:
            print(
                f"\n====================== {dataset_name}, {embedding_model} ======================"
            )
            print(f"Ingesting {dataset_name}, {embedding_model}")
            vector_db_id = inject_documents(
                llama_stack_client,
                corpus,
                batch_size,
                vector_db_provider_id,
                embedding_model,
            )

            query_config = RAGQueryConfig(max_chunks=10, mode="vector").model_dump()
            retriever = LlamaStackRAGRetriever(
                llama_stack_client, vector_db_id, query_config, top_k=10
            )

            print("Retrieving")
            results = retriever.retrieve(queries, top_k=10)

            print("Scoring")
            k_values = [5, 10]

            # This is a subset of the evaluation metrics used in beir.retrieval.evaluation.
            # It formulates the metric strings at https://github.com/beir-cellar/beir/blob/main/beir/retrieval/evaluation.py#L61
            # and then calls pytrec_eval.RelevanceEvaluator. We call pytrec_eval.RelevanceEvaluator directly using some of
            # those strings because we want not only the overall averages (which beir.retrieval.evaluation provides) but also
            # the scores for each question so we can compute statistical significance.
            map_string = "map_cut." + ",".join([str(k) for k in k_values])
            ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
            metrics_strings = {ndcg_string, map_string}

            evaluator = pytrec_eval.RelevanceEvaluator(qrels, metrics_strings)
            scores = evaluator.evaluate(results)

            all_scores[dataset_name][embedding_model] = scores

            os.makedirs(results_dir, exist_ok=True)
            util.save_runfile(
                os.path.join(
                    results_dir,
                    f"{dataset_name}-{vector_db_id}-{embedding_model}.run.trec",
                ),
                results,
            )
            print(f"All results in {results_dir}\n")

    return all_scores

if __name__ == "__main__":
    args = parse_args()

    # If custom dataset URLs are provided, there must be exactly one per dataset name.
    if args.custom_datasets_urls and len(args.custom_datasets_urls) != len(
        args.dataset_names
    ):
        raise ValueError(
            f"Number of custom dataset URLs ({len(args.custom_datasets_urls)}) must match "
            f"number of dataset names ({len(args.dataset_names)}). "
            f"Got URLs: {args.custom_datasets_urls}, dataset names: {args.dataset_names}"
        )

    llama_stack_client = LlamaStackAsLibraryClient("./run.yaml")
    llama_stack_client.initialize()
    all_scores = evaluate_retrieval(
        llama_stack_client,
        args.dataset_names,
        args.custom_datasets_urls,
        args.batch_size,
        "milvus",
        args.embedding_models,
    )
    print_scores(all_scores)
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
beir>=2.2.0
llama-stack>=0.2.13
pymilvus>=2.5.11
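These pins map to the script's main dependencies: beir supplies util, LoggingHandler, and GenericDataLoader; llama-stack supplies the library client and the RAG tool runtime; and pymilvus presumably backs the "milvus" vector_db provider selected in the main block.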
