
Commit 720104a

refactor: add flags for custom datasets and update requirements & run file
1 parent c80cc11 commit 720104a

File tree

3 files changed: +142 -56 lines changed


benchmarks/llama-stack-rag-with-beir/benchmark_beir_ls_vs_no_ls.py

Lines changed: 97 additions & 32 deletions
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+import argparse
 import itertools
 import logging
 import os
@@ -57,6 +57,35 @@
 # makes this assessment faster. Running on more datasets would make it more robust. So it is a tricky trade-off.
 # See a full list of available datasets at https://github.com/beir-cellar/beir?tab=readme-ov-file#beers-available-datasets
 DATASETS = ["scifact"]
+DEFAULT_CUSTOM_DATASETS_URLS = []
+DEFAULT_BATCH_SIZE = 150
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Benchmark embedding models with BEIR datasets")
+
+    parser.add_argument(
+        "--dataset-names",
+        nargs="+",
+        default=DATASETS,
+        help=f"List of BEIR datasets to evaluate (default: {DATASETS})"
+    )
+
+    parser.add_argument(
+        "--custom-datasets-urls",
+        nargs="+",
+        default=DEFAULT_CUSTOM_DATASETS_URLS,
+        help=f"Custom URLs for datasets (default: {DEFAULT_CUSTOM_DATASETS_URLS})"
+    )
+
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=DEFAULT_BATCH_SIZE,
+        help=f"Batch size for injecting documents (default: {DEFAULT_BATCH_SIZE})"
+    )
+
+    return parser.parse_args()
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -68,12 +97,15 @@
 
 
 # Load BEIR dataset
-def load_beir_dataset(dataset_name: str):
-    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
+def load_beir_dataset(dataset_name: str, custom_datasets_pairs: dict):
+    if custom_datasets_pairs == {}:
+        url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
+    else:
+        url = custom_datasets_pairs[dataset_name]
+
     out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
-
     data_path = os.path.join(out_dir, dataset_name)
-    print(data_path)
+
     if not os.path.isdir(data_path):
         data_path = util.download_and_unzip(url, out_dir)
 
@@ -83,7 +115,7 @@ def load_beir_dataset(dataset_name: str):
 
 # Inject documents into LlamaStack vector database
 def inject_documents_llama_stack(
-    llama_stack_client, corpus, vector_db_provider_id, embedding_model_id, chunk_size_in_tokens
+    llama_stack_client, corpus, vector_db_provider_id, embedding_model_id, chunk_size_in_tokens, batch_size
 ):
     vector_db_id = f"beir-rag-eval-{uuid.uuid4().hex}"
 
@@ -100,20 +132,27 @@ def inject_documents_llama_stack(
         embedding_dimension=embedding_dimension,
     )
 
-    # Convert corpus into Documents
-    documents = [
-        Document(
-            document_id=doc_id,
-            content=data["title"] + " " + data["text"],
-            mime_type="text/plain",
-            metadata={},
-        )
-        for doc_id, data in corpus.items()
-    ]
+    # Convert corpus into Documents and process in batches
+    corpus_items = list(corpus.items())
+    total_docs = len(corpus_items)
+
+    for i in range(0, total_docs, batch_size):
+        batch_items = corpus_items[i:i + batch_size]
+        documents_batch = [
+            Document(
+                document_id=doc_id,
+                content=data["title"] + " " + data["text"],
+                mime_type="text/plain",
+                metadata={},
+            )
+            for doc_id, data in batch_items
+        ]
 
-    llama_stack_client.tool_runtime.rag_tool.insert(
-        documents=documents, vector_db_id=vector_db_id, chunk_size_in_tokens=chunk_size_in_tokens, timeout=36000
-    )
+        print(f"Inserting batch {i//batch_size + 1}/{(total_docs + batch_size - 1)//batch_size} ({len(documents_batch)} docs)")
+
+        llama_stack_client.tool_runtime.rag_tool.insert(
+            documents=documents_batch, vector_db_id=vector_db_id, chunk_size_in_tokens=chunk_size_in_tokens, timeout=36000
+        )
 
     return vector_db_id
 
@@ -182,7 +221,7 @@ def make_overlapped_chunks(
 
 
 # Inject documents directly into a Milvus lite vector database using the Milvus APIs
-def inject_documents_milvus(corpus, embedding_model_id, chunk_size_in_tokens):
+def inject_documents_milvus(corpus, embedding_model_id, chunk_size_in_tokens, batch_size):
     collection_name = f"beir_eval_{uuid.uuid4().hex}"
 
     embedding_model = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_model_id, device="mps")
@@ -192,14 +231,23 @@ def inject_documents_milvus(corpus, embedding_model_id, chunk_size_in_tokens):
     milvus_client = MilvusClient(db_file)
     milvus_client.create_collection(collection_name=collection_name, dimension=int(embedding_dimension), auto_id=True)
 
-    documents = []
-    for doc_id, data in corpus.items():
-        full_text = data["title"] + " " + data["text"]
-        chunks = llama_stack_style_chunker(full_text, chunk_size_in_tokens)
-        for chunk in chunks:
-            documents.append({"doc_id": doc_id, "vector": embedding_model.encode_documents([chunk])[0], "text": chunk})
+    # Convert corpus into list and process in batches
+    corpus_items = list(corpus.items())
+    total_docs = len(corpus_items)
+
+    for i in range(0, total_docs, batch_size):
+        batch_items = corpus_items[i:i + batch_size]
+        documents_batch = []
+
+        for doc_id, data in batch_items:
+            full_text = data["title"] + " " + data["text"]
+            chunks = llama_stack_style_chunker(full_text, chunk_size_in_tokens)
+            for chunk in chunks:
+                documents_batch.append({"doc_id": doc_id, "vector": embedding_model.encode_documents([chunk])[0], "text": chunk})
+
+        print(f"Inserting batch {i//batch_size + 1}/{(total_docs + batch_size - 1)//batch_size} ({len(documents_batch)} chunks)")
+        milvus_client.insert(collection_name=collection_name, data=documents_batch)
 
-    milvus_client.insert(collection_name=collection_name, data=documents)
     return milvus_client, collection_name, embedding_model
 
 
@@ -362,26 +410,33 @@ def print_scores(all_scores):
 def evaluate_retrieval_with_and_without_llama_stack(
     llama_stack_client,
     datasets,
+    custom_datasets_urls,
     vector_db_provider_id,
     embedding_model_id,
+    batch_size,
     chunk_size_in_tokens=512,
     number_of_search_results=10,
     save_files=False,
 ):
     all_scores = {}
     results_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "results")
+
+    custom_datasets_pairs = {}
+    if custom_datasets_urls:
+        custom_datasets_pairs = {dataset_name: custom_datasets_urls[i] for i, dataset_name in enumerate(datasets)}
+
     for dataset_name in datasets:
         all_scores[dataset_name] = {}
-        corpus, queries, qrels = load_beir_dataset(dataset_name)
-
+        corpus, queries, qrels = load_beir_dataset(dataset_name, custom_datasets_pairs)
+
         # Uncomment this line to select only a few documents for debugging
         #corpus = pick_arbitrary_pairs(corpus)
 
         retrievers = {}
 
         logger.info(f"Ingesting {dataset_name}, LlamaStackRAGRetriever")
         vector_db_id = inject_documents_llama_stack(
-            llama_stack_client, corpus, vector_db_provider_id, embedding_model_id, chunk_size_in_tokens
+            llama_stack_client, corpus, vector_db_provider_id, embedding_model_id, chunk_size_in_tokens, batch_size
         )
 
         # We set max_tokens_in_context=chunk_size_in_tokens*number_of_search_results so that we won't get errors saying that we have too many tokens.
@@ -395,7 +450,7 @@ def evaluate_retrieval_with_and_without_llama_stack(
 
         print(f"Ingesting {dataset_name}, MilvusRetriever")
         milvus_client, collection_name, embedding_model = inject_documents_milvus(
-            corpus, embedding_model_id, chunk_size_in_tokens
+            corpus, embedding_model_id, chunk_size_in_tokens, batch_size
         )
         milvus_retriever = MilvusRetriever(
             milvus_client, collection_name, embedding_model, top_k=number_of_search_results
@@ -462,11 +517,21 @@ def pick_arbitrary_pairs(input_dict, num_pairs=5):
 
 
 if __name__ == "__main__":
+    args = parse_args()
+
+    # A check for when custom dataset urls are set they are compared with the number of dataset names
+    if args.custom_datasets_urls and len(args.custom_datasets_urls) != len(args.dataset_names):
+        raise ValueError(
+            f"Number of custom dataset URLs ({len(args.custom_datasets_urls)}) must match "
+            f"number of dataset names ({len(args.dataset_names)}). "
+            f"Got URLs: {args.custom_datasets_urls}, dataset names: {args.dataset_names}"
+        )
+
     llama_stack_client = LlamaStackAsLibraryClient("./run.yaml")
     llama_stack_client.initialize()
 
     all_scores = evaluate_retrieval_with_and_without_llama_stack(
-        llama_stack_client, DATASETS, "milvus", "ibm-granite/granite-embedding-125m-english"
+        llama_stack_client, args.dataset_names, args.custom_datasets_urls, "milvus", "ibm-granite/granite-embedding-125m-english", args.batch_size
     )
     has_significant_difference = print_scores(all_scores)
     if has_significant_difference:
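
For orientation on the benchmark changes above, here is a minimal sketch (not part of the commit) of how the new flags fit together: dataset names and custom URLs are paired positionally, and the batch size determines how many insert calls each corpus is split into. The invocation, dataset names, URLs, and document count below are hypothetical.

# Example invocation (hypothetical URL and values):
#   python benchmark_beir_ls_vs_no_ls.py --dataset-names scifact \
#       --custom-datasets-urls https://example.com/beir/scifact.zip --batch-size 200

dataset_names = ["scifact", "nfcorpus"]
custom_urls = [
    "https://example.com/beir/scifact.zip",
    "https://example.com/beir/nfcorpus.zip",
]

# Same length check as in __main__: each dataset name needs exactly one custom URL.
if custom_urls and len(custom_urls) != len(dataset_names):
    raise ValueError("number of custom URLs must match number of dataset names")

# Positional pairing, mirroring the dict comprehension added in
# evaluate_retrieval_with_and_without_llama_stack.
custom_datasets_pairs = {name: custom_urls[i] for i, name in enumerate(dataset_names)}
print(custom_datasets_pairs)

# Batch arithmetic used by both insertion loops: ceil(total_docs / batch_size) batches.
total_docs, batch_size = 5000, 150
num_batches = (total_docs + batch_size - 1) // batch_size  # 34 for these illustrative numbers
print(num_batches)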
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+llama-stack>=0.2.13
+pymilvus>=2.5.12
+pytrec-eval>=0.5
+beir>=2.2.0
+pymilvus-model>=0.3.2
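
If this listing is the benchmark's requirements.txt (the commit message's "update requirements" suggests so), the dependencies can be installed with pip install -r requirements.txt from the benchmarks/llama-stack-rag-with-beir directory before running the script.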

benchmarks/llama-stack-rag-with-beir/run.yaml

Lines changed: 40 additions & 24 deletions
@@ -1,10 +1,12 @@
-version: '2'
+version: 2
 image_name: ollama
 apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
+- post_training
 - safety
 - scoring
 - telemetry
@@ -15,19 +17,20 @@ providers:
   - provider_id: ollama
     provider_type: remote::ollama
     config:
-      url: ${env.OLLAMA_URL:http://localhost:11434}
+      url: ${env.OLLAMA_URL:=http://localhost:11434}
+      raise_on_connect_error: true
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
   vector_io:
   - provider_id: milvus
     provider_type: inline::milvus
     config:
-      db_path: ${env.MILVUS_STORE_DIR:~/.llama/distributions/ollama}/milvus.db
-  - provider_id: sqlite-vec
-    provider_type: inline::sqlite-vec
-    config:
-      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec_store.db
+      db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ollama/milvus.db}
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/${env.MILVUS_KVSTORE_DB_PATH:=~/.llama/distributions/ollama/milvus_registry.db}
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -40,37 +43,33 @@ providers:
       persistence_store:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
-  telemetry:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
-    config:
-      service_name: ${env.OTEL_SERVICE_NAME:}
-      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
-      sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/trace_store.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db
+      responses_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db
   eval:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
     config:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/meta_reference_eval.db
   datasetio:
   - provider_id: huggingface
     provider_type: remote::huggingface
     config:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/huggingface_datasetio.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
   - provider_id: localfs
     provider_type: inline::localfs
     config:
       kvstore:
         type: sqlite
         namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db
   scoring:
   - provider_id: basic
     provider_type: inline::basic
@@ -81,17 +80,32 @@ providers:
   - provider_id: braintrust
     provider_type: inline::braintrust
     config:
-      openai_api_key: ${env.OPENAI_API_KEY:}
+      openai_api_key: ${env.OPENAI_API_KEY:+}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ollama/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/files_metadata.db
+  post_training:
+  - provider_id: huggingface
+    provider_type: inline::huggingface
+    config:
+      checkpoint_format: huggingface
+      distributed_backend: null
+      device: cpu
   tool_runtime:
   - provider_id: brave-search
     provider_type: remote::brave-search
     config:
-      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      api_key: ${env.BRAVE_SEARCH_API_KEY:+}
       max_results: 3
   - provider_id: tavily-search
     provider_type: remote::tavily-search
     config:
-      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      api_key: ${env.TAVILY_SEARCH_API_KEY:+}
       max_results: 3
   - provider_id: rag-runtime
     provider_type: inline::rag-runtime
@@ -102,10 +116,13 @@ providers:
   - provider_id: wolfram-alpha
     provider_type: remote::wolfram-alpha
     config:
-      api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
+      api_key: ${env.WOLFRAM_ALPHA_API_KEY:+}
 metadata_store:
   type: sqlite
-  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
+inference_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/inference_store.db
 models:
 - metadata: {}
   model_id: ${env.INFERENCE_MODEL}
@@ -143,4 +160,3 @@ tool_groups:
   provider_id: wolfram-alpha
 server:
   port: 8321
-  disable_ipv6: false
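
Taken together, the run.yaml changes above track newer Llama Stack conventions (a reading of the diff, not text from the commit): environment placeholders move from the older ${env.VAR:default} and ${env.VAR:} forms to ${env.VAR:=default}, which substitutes the default when the variable is unset, and ${env.VAR:+}, used for values such as API keys that should only be supplied when the variable is set; the sqlite-vec vector_io provider and the telemetry block are removed; and files, post_training, responses_store, and inference_store sections are added.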
