@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+import argparse
 import itertools
 import logging
 import os
@@ -57,6 +57,35 @@
 # makes this assessment faster. Running on more datasets would make it more robust. So it is a tricky trade-off.
 # See a full list of available datasets at https://github.com/beir-cellar/beir?tab=readme-ov-file#beers-available-datasets
 DATASETS = ["scifact"]
+DEFAULT_CUSTOM_DATASETS_URLS = []
+DEFAULT_BATCH_SIZE = 150
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Benchmark embedding models with BEIR datasets")
+
+    parser.add_argument(
+        "--dataset-names",
+        nargs="+",
+        default=DATASETS,
+        help=f"List of BEIR datasets to evaluate (default: {DATASETS})"
+    )
+
+    parser.add_argument(
+        "--custom-datasets-urls",
+        nargs="+",
+        default=DEFAULT_CUSTOM_DATASETS_URLS,
+        help=f"Custom URLs for datasets (default: {DEFAULT_CUSTOM_DATASETS_URLS})"
+    )
+
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=DEFAULT_BATCH_SIZE,
+        help=f"Batch size for injecting documents (default: {DEFAULT_BATCH_SIZE})"
+    )
+
+    return parser.parse_args()
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -68,12 +97,15 @@
 
 
 # Load BEIR dataset
-def load_beir_dataset(dataset_name: str):
-    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
+def load_beir_dataset(dataset_name: str, custom_datasets_pairs: dict):
+    if custom_datasets_pairs == {}:
+        url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
+    else:
+        url = custom_datasets_pairs[dataset_name]
+
     out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
-
     data_path = os.path.join(out_dir, dataset_name)
-    print(data_path)
+
     if not os.path.isdir(data_path):
         data_path = util.download_and_unzip(url, out_dir)
 
@@ -83,7 +115,7 @@ def load_beir_dataset(dataset_name: str):
 
 # Inject documents into LlamaStack vector database
 def inject_documents_llama_stack(
-    llama_stack_client, corpus, vector_db_provider_id, embedding_model_id, chunk_size_in_tokens
+    llama_stack_client, corpus, vector_db_provider_id, embedding_model_id, chunk_size_in_tokens, batch_size
 ):
 
     vector_db_id = f"beir-rag-eval-{uuid.uuid4().hex}"
@@ -100,20 +132,27 @@ def inject_documents_llama_stack(
         embedding_dimension=embedding_dimension,
     )
 
-    # Convert corpus into Documents
-    documents = [
-        Document(
-            document_id=doc_id,
-            content=data["title"] + " " + data["text"],
-            mime_type="text/plain",
-            metadata={},
-        )
-        for doc_id, data in corpus.items()
-    ]
+    # Convert corpus into Documents and process in batches
+    corpus_items = list(corpus.items())
+    total_docs = len(corpus_items)
+
+    for i in range(0, total_docs, batch_size):
+        batch_items = corpus_items[i:i + batch_size]
+        documents_batch = [
+            Document(
+                document_id=doc_id,
+                content=data["title"] + " " + data["text"],
+                mime_type="text/plain",
+                metadata={},
+            )
+            for doc_id, data in batch_items
+        ]
 
-    llama_stack_client.tool_runtime.rag_tool.insert(
-        documents=documents, vector_db_id=vector_db_id, chunk_size_in_tokens=chunk_size_in_tokens, timeout=36000
-    )
+        print(f"Inserting batch {i // batch_size + 1}/{(total_docs + batch_size - 1) // batch_size} ({len(documents_batch)} docs)")
+
+        llama_stack_client.tool_runtime.rag_tool.insert(
+            documents=documents_batch, vector_db_id=vector_db_id, chunk_size_in_tokens=chunk_size_in_tokens, timeout=36000
+        )
 
     return vector_db_id
 
@@ -182,7 +221,7 @@ def make_overlapped_chunks(
 
 
 # Inject documents directly into a Milvus lite vector database using the Milvus APIs
-def inject_documents_milvus(corpus, embedding_model_id, chunk_size_in_tokens):
+def inject_documents_milvus(corpus, embedding_model_id, chunk_size_in_tokens, batch_size):
     collection_name = f"beir_eval_{uuid.uuid4().hex}"
 
     embedding_model = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_model_id, device="mps")
@@ -192,14 +231,23 @@ def inject_documents_milvus(corpus, embedding_model_id, chunk_size_in_tokens):
     milvus_client = MilvusClient(db_file)
     milvus_client.create_collection(collection_name=collection_name, dimension=int(embedding_dimension), auto_id=True)
 
-    documents = []
-    for doc_id, data in corpus.items():
-        full_text = data["title"] + " " + data["text"]
-        chunks = llama_stack_style_chunker(full_text, chunk_size_in_tokens)
-        for chunk in chunks:
-            documents.append({"doc_id": doc_id, "vector": embedding_model.encode_documents([chunk])[0], "text": chunk})
+    # Convert corpus into list and process in batches
+    corpus_items = list(corpus.items())
+    total_docs = len(corpus_items)
+
+    for i in range(0, total_docs, batch_size):
+        batch_items = corpus_items[i:i + batch_size]
+        documents_batch = []
+
+        for doc_id, data in batch_items:
+            full_text = data["title"] + " " + data["text"]
+            chunks = llama_stack_style_chunker(full_text, chunk_size_in_tokens)
+            for chunk in chunks:
+                documents_batch.append({"doc_id": doc_id, "vector": embedding_model.encode_documents([chunk])[0], "text": chunk})
+
+        print(f"Inserting batch {i // batch_size + 1}/{(total_docs + batch_size - 1) // batch_size} ({len(documents_batch)} chunks)")
+        milvus_client.insert(collection_name=collection_name, data=documents_batch)
 
-    milvus_client.insert(collection_name=collection_name, data=documents)
     return milvus_client, collection_name, embedding_model
 
 
@@ -362,26 +410,33 @@ def print_scores(all_scores):
 def evaluate_retrieval_with_and_without_llama_stack(
     llama_stack_client,
     datasets,
+    custom_datasets_urls,
     vector_db_provider_id,
     embedding_model_id,
+    batch_size,
     chunk_size_in_tokens=512,
     number_of_search_results=10,
     save_files=False,
 ):
     all_scores = {}
     results_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "results")
+
+    custom_datasets_pairs = {}
+    if custom_datasets_urls:
+        custom_datasets_pairs = {dataset_name: custom_datasets_urls[i] for i, dataset_name in enumerate(datasets)}
+
     for dataset_name in datasets:
         all_scores[dataset_name] = {}
-        corpus, queries, qrels = load_beir_dataset(dataset_name)
-
+        corpus, queries, qrels = load_beir_dataset(dataset_name, custom_datasets_pairs)
+
         # Uncomment this line to select only a few documents for debugging
         #corpus = pick_arbitrary_pairs(corpus)
 
         retrievers = {}
 
         logger.info(f"Ingesting {dataset_name}, LlamaStackRAGRetriever")
         vector_db_id = inject_documents_llama_stack(
-            llama_stack_client, corpus, vector_db_provider_id, embedding_model_id, chunk_size_in_tokens
+            llama_stack_client, corpus, vector_db_provider_id, embedding_model_id, chunk_size_in_tokens, batch_size
         )
 
         # We set max_tokens_in_context=chunk_size_in_tokens*number_of_search_results so that we won't get errors saying that we have too many tokens.
@@ -395,7 +450,7 @@ def evaluate_retrieval_with_and_without_llama_stack(
 
         print(f"Ingesting {dataset_name}, MilvusRetriever")
         milvus_client, collection_name, embedding_model = inject_documents_milvus(
-            corpus, embedding_model_id, chunk_size_in_tokens
+            corpus, embedding_model_id, chunk_size_in_tokens, batch_size
         )
         milvus_retriever = MilvusRetriever(
             milvus_client, collection_name, embedding_model, top_k=number_of_search_results
@@ -462,11 +517,21 @@ def pick_arbitrary_pairs(input_dict, num_pairs=5):
 
 
 if __name__ == "__main__":
+    args = parse_args()
+
+    # If custom dataset URLs are provided, their count must match the number of dataset names
+    if args.custom_datasets_urls and len(args.custom_datasets_urls) != len(args.dataset_names):
+        raise ValueError(
+            f"Number of custom dataset URLs ({len(args.custom_datasets_urls)}) must match "
+            f"number of dataset names ({len(args.dataset_names)}). "
+            f"Got URLs: {args.custom_datasets_urls}, dataset names: {args.dataset_names}"
+        )
+
     llama_stack_client = LlamaStackAsLibraryClient("./run.yaml")
     llama_stack_client.initialize()
 
     all_scores = evaluate_retrieval_with_and_without_llama_stack(
-        llama_stack_client, DATASETS, "milvus", "ibm-granite/granite-embedding-125m-english"
+        llama_stack_client, args.dataset_names, args.custom_datasets_urls, "milvus", "ibm-granite/granite-embedding-125m-english", args.batch_size
     )
     has_significant_difference = print_scores(all_scores)
    if has_significant_difference:
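
For reference, a possible invocation of the CLI flags this diff adds (the script filename, the batch size of 100, and the example URL are assumptions for illustration, not taken from the diff; when --custom-datasets-urls is given, the number of URLs must match the number of dataset names):

    python beir_benchmark.py --dataset-names scifact --batch-size 100
    python beir_benchmark.py --dataset-names scifact --custom-datasets-urls https://example.com/datasets/scifact.zip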