Skip to content

Commit 49e0772

Browse files
authored
Add local benchmarking script (#459)
1 parent 14fd1a8 commit 49e0772

File tree

3 files changed

+334
-1
lines changed

3 files changed

+334
-1
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,6 @@ documentation/reference
107107

108108
# CMake
109109
*/cmake-build-*
110+
111+
# Benchmarking temporary files
112+
tmp/
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
# Used to benchmark ingestion and querying running locally. First downloads SIFT and then
2+
# benchmarks ingestion and querying.
3+
#
4+
# To run:
5+
# - ~/repo/TileDB-Vector-Search pip install ".[benchmarks]"
6+
# - ~/repo/TileDB-Vector-Search python apis/python/test/local-benchmarks.py
7+
8+
import os
9+
import shutil
10+
import tarfile
11+
import time
12+
import urllib.request
13+
from enum import Enum
14+
15+
import matplotlib
16+
import matplotlib.pyplot as plt
17+
from common import accuracy
18+
from common import get_groundtruth_ivec
19+
20+
from tiledb.vector_search.ingestion import TrainingSamplingPolicy
21+
from tiledb.vector_search.ingestion import ingest
22+
from tiledb.vector_search.utils import load_fvecs
23+
24+
# Select the "Agg" backend: this script only writes charts to image files with
# plt.savefig() and never opens a window.
# NOTE(review): Agg is a non-GUI backend per the matplotlib docs — confirm this
# is still wanted if interactive plotting is ever added.
matplotlib.use("Agg")
25+
26+
# Switch between the small SIFT dataset ("siftsmall") and the full one ("sift").
USE_SIFT_SMALL = True

# The archive name, the folder inside the archive and the prefix of every data
# file are all the same string, so everything below is derived from this name.
SIFT_FOLDER_NAME = "siftsmall" if USE_SIFT_SMALL else "sift"

SIFT_URI = f"ftp://ftp.irisa.fr/local/texmex/corpus/{SIFT_FOLDER_NAME}.tar.gz"

# Scratch directory next to this script; created eagerly at import time.
TEMP_DIR = os.path.join(os.path.dirname(__file__), "tmp")
os.makedirs(TEMP_DIR, exist_ok=True)

# Where the downloaded archive is stored.
SIFT_DOWNLOAD_PATH = os.path.join(TEMP_DIR, f"{SIFT_FOLDER_NAME}.tar.gz")

# Files expected under TEMP_DIR/<folder> once the archive has been extracted.
SIFT_BASE_PATH = os.path.join(
    TEMP_DIR, SIFT_FOLDER_NAME, f"{SIFT_FOLDER_NAME}_base.fvecs"
)
SIFT_QUERIES_PATH = os.path.join(
    TEMP_DIR, SIFT_FOLDER_NAME, f"{SIFT_FOLDER_NAME}_query.fvecs"
)
SIFT_GROUNDTRUTH_PATH = os.path.join(
    TEMP_DIR, SIFT_FOLDER_NAME, f"{SIFT_FOLDER_NAME}_groundtruth.ivecs"
)
56+
57+
58+
class TimerMode(Enum):
    """Phase of a benchmark that a Timer measurement belongs to.

    The member values are the suffixes Timer appends to a tag when it builds
    its internal "<tag>_<mode>" keys.
    """

    # Building the index from the source vectors.
    INGESTION = "ingestion"
    # Running queries against an already built index.
    QUERY = "query"
61+
62+
63+
class Timer:
    """Collects per-tag timings and accuracies and reports them.

    A measurement is identified by a ``tag`` (free-form string naming the
    configuration) and a ``mode`` (an object with a ``value`` attribute of
    "ingestion" or "query", e.g. a TimerMode member). Several start/stop pairs
    may be recorded for the same tag and mode; each one adds an interval.
    """

    def __init__(self):
        # "<tag>_<mode>" -> start timestamp of a timer that is currently running.
        self.current_timers = {}

        # "<tag>_<mode>" -> elapsed seconds, one entry per start/stop pair.
        self.keyToTimes = {}
        # tag -> accuracies in the order they were recorded.
        self.tagToAccuracies = {}

    def start(self, tag, mode):
        """Start timing ``tag`` in ``mode``.

        Raises:
            ValueError: if the same tag/mode pair is already being timed.
        """
        key = f"{tag}_{mode.value}"
        if key in self.current_timers:
            raise ValueError(f"Timer {tag} already started.")
        # perf_counter() is monotonic and high resolution, so the measured
        # interval cannot be skewed by system clock adjustments.
        self.current_timers[key] = time.perf_counter()

    def stop(self, tag, mode):
        """Stop timing ``tag`` in ``mode`` and record the interval.

        Returns:
            The elapsed time in seconds since the matching start() call.

        Raises:
            ValueError: if the tag/mode pair was not started.
        """
        key = f"{tag}_{mode.value}"
        if key not in self.current_timers:
            raise ValueError(f"Timer {tag} not started.")
        elapsed = time.perf_counter() - self.current_timers.pop(key)
        self.keyToTimes.setdefault(key, []).append(elapsed)
        return elapsed

    def accuracy(self, tag, acc):
        """Record accuracy ``acc`` for ``tag`` and return it unchanged."""
        self.tagToAccuracies.setdefault(tag, []).append(acc)
        return acc

    def summarize_data(self):
        """Aggregate the recorded measurements per tag.

        Returns:
            A dict mapping each tag that has at least one recorded interval to
            ``{"ingestion": {"total_time", "count", "times"},
            "query": {"total_time", "count", "accuracies", "times"}}``.
            Both sections are always present; a section without measurements
            has a count of 0 and empty lists. Accuracies recorded for a tag
            that has no timings are not included.
        """
        summary = {}
        for key, intervals in self.keyToTimes.items():
            # Mode values contain no underscore, so the last "_" separates
            # the tag (which may itself contain underscores) from the mode.
            tag, mode = key.rsplit("_", 1)
            if tag not in summary:
                summary[tag] = {
                    "ingestion": {"total_time": 0, "count": 0, "times": []},
                    "query": {
                        "total_time": 0,
                        "count": 0,
                        "accuracies": [],
                        "times": [],
                    },
                }
            if mode in ("ingestion", "query"):
                section = summary[tag][mode]
                section["total_time"] += sum(intervals)
                section["count"] += len(intervals)
                # Copy so callers cannot mutate the timer's own records.
                section["times"] = list(intervals)

        for tag, accuracies in self.tagToAccuracies.items():
            if tag in summary:
                summary[tag]["query"]["accuracies"] = list(accuracies)

        return summary

    def summarize(self):
        """Return a human-readable report of average times and accuracy.

        Sections without any measurement are omitted instead of dividing by a
        zero count.
        """
        lines = []
        for tag, data in self.summarize_data().items():
            lines.append(f"{tag}\n")

            ingestion = data["ingestion"]
            if ingestion["count"]:
                average = ingestion["total_time"] / ingestion["count"]
                lines.append(f" Ingestion (count: {ingestion['count']}):\n")
                lines.append(f" Average Time: {average:.4f} seconds\n")

            query = data["query"]
            if query["count"]:
                average = query["total_time"] / query["count"]
                lines.append(f" Query (count: {query['count']}):\n")
                lines.append(f" Average Time: {average:.4f} seconds\n")
                if query["accuracies"]:
                    average_accuracy = sum(query["accuracies"]) / len(
                        query["accuracies"]
                    )
                    lines.append(f" Average Accuracy: {average_accuracy:.4f}\n")

            lines.append("\n")
        return "".join(lines)

    def create_charts(self):
        """Write the ingestion and query charts as PNG files into TEMP_DIR.

        Tags that lack the data a chart needs are left out of that chart
        rather than aborting the whole run.
        """
        summary = self.summarize_data()

        # Plot ingestion: one point per ingestion run, placed at the tag's
        # average query accuracy.
        plt.figure(figsize=(20, 12))
        plt.xlabel("Average Query Accuracy")
        plt.ylabel("Time (seconds)")
        plt.title("Ingestion Time vs Average Query Accuracy")
        for tag, data in summary.items():
            accuracies = data["query"]["accuracies"]
            times = data["ingestion"]["times"]
            if not accuracies or not times:
                # Nothing to place on the accuracy axis, or nothing to plot.
                continue
            average_accuracy = sum(accuracies) / len(accuracies)
            plt.scatter(
                [average_accuracy] * len(times), times, marker="o", label=tag
            )

        plt.legend()
        plt.savefig(os.path.join(TEMP_DIR, "ingestion_time_vs_accuracy.png"))
        plt.close()

        # Plot query: one point per query run, in recording order.
        plt.figure(figsize=(20, 12))
        plt.xlabel("Accuracy")
        plt.ylabel("Time (seconds)")
        plt.title("Query Time vs Accuracy")
        for tag, data in summary.items():
            # zip() pairs the i-th time with the i-th accuracy and stops at
            # the shorter list, so a missing accuracy cannot raise IndexError.
            pairs = list(zip(data["query"]["times"], data["query"]["accuracies"]))
            if not pairs:
                continue
            times, accuracies = zip(*pairs)
            plt.plot(accuracies, times, marker="o", label=tag)

        plt.legend()
        plt.savefig(os.path.join(TEMP_DIR, "query_time_vs_accuracy.png"))
        plt.close()
182+
183+
184+
def download_and_extract(url, download_path, extract_path):
    """Download a gzip-compressed tar archive (if needed) and extract it.

    Args:
        url: Location of the ``.tar.gz`` archive to fetch.
        download_path: Local file the archive is stored at. If it already
            exists the download is skipped.
        extract_path: Directory the archive is extracted into; created if it
            does not exist. Extraction runs on every call, even when the
            download was skipped.
    """
    if os.path.exists(download_path):
        print(
            f"Skipping download of {url} to {download_path} because it already exists."
        )
    else:
        print(f"Downloading {url} to {download_path}.")
        # Download to a side file and rename it only once complete; otherwise
        # an interrupted transfer would leave a truncated archive that later
        # runs would treat as already downloaded.
        partial_path = f"{download_path}.part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, download_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print("Finished download.")

    print("Extracting files.")
    os.makedirs(extract_path, exist_ok=True)
    with tarfile.open(download_path, "r:gz") as tar:
        # The archive comes from the network, so use the "data" extraction
        # filter where the running Python provides it; it rejects members
        # that would be written outside extract_path.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=extract_path, filter="data")
        else:
            tar.extractall(path=extract_path)
    print("Finished extracting files.")
198+
199+
200+
def benchmark_ivf_flat():
    """Benchmark IVF_FLAT ingestion and querying for several partition counts.

    For each partition count the index under TEMP_DIR is removed and rebuilt,
    then queried once per nprobe value. Each run prints its timings and
    accuracy; at the end the Timer summary is printed and charts are written.
    """
    index_type = "IVF_FLAT"
    stopwatch = Timer()

    top_k = 100
    query_vectors = load_fvecs(SIFT_QUERIES_PATH)
    ground_truth_ids, _ = get_groundtruth_ivec(
        SIFT_GROUNDTRUTH_PATH, k=top_k, nqueries=len(query_vectors)
    )

    # Every configuration is written to the same location.
    index_uri = os.path.join(TEMP_DIR, f"index_{index_type}")

    for partitions in (20, 50, 100, 200):
        tag = f"{index_type}_partitions={partitions}"
        print(f"Running {tag}")

        # Remove whatever a previous configuration or run left behind.
        if os.path.exists(index_uri):
            shutil.rmtree(index_uri)

        stopwatch.start(tag, TimerMode.INGESTION)
        built_index = ingest(
            index_type=index_type,
            index_uri=index_uri,
            source_uri=SIFT_BASE_PATH,
            partitions=partitions,
            training_sampling_policy=TrainingSamplingPolicy.RANDOM,
        )
        ingest_time = stopwatch.stop(tag, TimerMode.INGESTION)

        for nprobe in (1, 2, 3, 4, 5, 10, 20):
            stopwatch.start(tag, TimerMode.QUERY)
            _, neighbors = built_index.query(query_vectors, k=top_k, nprobe=nprobe)
            query_time = stopwatch.stop(tag, TimerMode.QUERY)
            acc = stopwatch.accuracy(tag, accuracy(neighbors, ground_truth_ids))
            print(
                f"Finished {tag} with nprobe={nprobe}. Ingestion: {ingest_time:.4f}s. Query: {query_time:.4f}s. Accuracy: {acc:.4f}."
            )

    print(stopwatch.summarize())
    stopwatch.create_charts()
237+
238+
239+
def benchmark_vamana():
    """Benchmark VAMANA ingestion and querying over build parameters.

    For every (l_build, r_max_degree) combination the index under TEMP_DIR is
    removed and rebuilt, then queried once per l_search value. Each run prints
    its timings and accuracy; at the end the Timer summary is printed and
    charts are written.
    """
    index_type = "VAMANA"
    stopwatch = Timer()

    k = 100
    query_vectors = load_fvecs(SIFT_QUERIES_PATH)
    ground_truth_ids, _ = get_groundtruth_ivec(
        SIFT_GROUNDTRUTH_PATH, k=k, nqueries=len(query_vectors)
    )

    # Every configuration is written to the same location.
    index_uri = os.path.join(TEMP_DIR, f"index_{index_type}")
    # Search list sizes to try, all at least as large as k.
    l_search_values = [k + extra for extra in (0, 50, 100, 200, 400)]

    for l_build in (10, 25, 40):
        for r_max_degree in (10, 25):
            tag = f"{index_type}_l_build={l_build}_r_max_degree={r_max_degree}"
            print(f"Running {tag}")

            # Remove whatever a previous configuration or run left behind.
            if os.path.exists(index_uri):
                shutil.rmtree(index_uri)

            stopwatch.start(tag, TimerMode.INGESTION)
            built_index = ingest(
                index_type=index_type,
                index_uri=index_uri,
                source_uri=SIFT_BASE_PATH,
                l_build=l_build,
                r_max_degree=r_max_degree,
                training_sampling_policy=TrainingSamplingPolicy.RANDOM,
            )
            ingest_time = stopwatch.stop(tag, TimerMode.INGESTION)

            for l_search in l_search_values:
                stopwatch.start(tag, TimerMode.QUERY)
                _, neighbors = built_index.query(
                    query_vectors, k=k, l_search=l_search
                )
                query_time = stopwatch.stop(tag, TimerMode.QUERY)
                acc = stopwatch.accuracy(tag, accuracy(neighbors, ground_truth_ids))
                print(
                    f"Finished {tag} with l_search={l_search}. Ingestion: {ingest_time:.4f}s. Query: {query_time:.4f}s. Accuracy: {acc:.4f}."
                )

    print(stopwatch.summarize())
    stopwatch.create_charts()
278+
279+
280+
def benchmark_ivf_pq():
    """Benchmark IVF_PQ ingestion and querying over several subspace counts.

    For every (partitions, num_subspaces) combination the index under TEMP_DIR
    is removed and rebuilt, then queried once per nprobe value. Each run prints
    its timings and accuracy; at the end the Timer summary is printed and
    charts are written.
    """
    index_type = "IVF_PQ"
    timer = Timer()

    k = 100
    queries = load_fvecs(SIFT_QUERIES_PATH)
    dimensions = queries.shape[1]
    gt_i, _ = get_groundtruth_ivec(SIFT_GROUNDTRUTH_PATH, k=k, nqueries=len(queries))

    for partitions in [50]:
        # Floor division keeps num_subspaces an integer; true division would
        # pass a float (e.g. 64.0) to ingest() and put "64.0" into the tag.
        for num_subspaces in [dimensions // 2, dimensions // 4, dimensions // 8]:
            tag = f"{index_type}_partitions={partitions}_num_subspaces={num_subspaces}"
            print(f"Running {tag}")

            # Remove whatever a previous configuration or run left behind.
            index_uri = os.path.join(TEMP_DIR, f"index_{index_type}")
            if os.path.exists(index_uri):
                shutil.rmtree(index_uri)

            timer.start(tag, TimerMode.INGESTION)
            index = ingest(
                index_type=index_type,
                index_uri=index_uri,
                source_uri=SIFT_BASE_PATH,
                partitions=partitions,
                training_sampling_policy=TrainingSamplingPolicy.RANDOM,
                num_subspaces=num_subspaces,
            )
            ingest_time = timer.stop(tag, TimerMode.INGESTION)

            for nprobe in [5, 10, 20, 40, 60]:
                timer.start(tag, TimerMode.QUERY)
                _, result = index.query(queries, k=k, nprobe=nprobe)
                query_time = timer.stop(tag, TimerMode.QUERY)
                acc = timer.accuracy(tag, accuracy(result, gt_i))
                print(
                    f"Finished {tag} with nprobe={nprobe}. Ingestion: {ingest_time:.4f}s. Query: {query_time:.4f}s. Accuracy: {acc:.4f}."
                )

    print(timer.summarize())
    timer.create_charts()
320+
321+
322+
def main():
    """Fetch and extract the SIFT dataset, then run the selected benchmark."""
    download_and_extract(SIFT_URI, SIFT_DOWNLOAD_PATH, TEMP_DIR)

    # Choose which benchmarks run by (un)commenting the calls below.
    # benchmark_ivf_flat()
    benchmark_vamana()
    # benchmark_ivf_pq()


# Run only when executed as a script, so that importing this module does not
# trigger a download and a full benchmark run.
if __name__ == "__main__":
    main()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ dependencies = [
2828
[project.optional-dependencies]
2929
test = ["nbmake", "pytest<8.0.0", "pytest-xdist"]
3030
formatting = ["pre-commit"]
31-
benchmarks = ["boto3", "paramiko"]
31+
benchmarks = ["boto3", "paramiko", "matplotlib"]
3232

3333
[project.urls]
3434
homepage = "https://tiledb.com"

0 commit comments

Comments
 (0)