Commit 173a5fb

Improve benchmark script - add other vector search libraries and download full results (#415)

1 parent fb12012 commit 173a5fb
1 file changed: src/benchmarks/ann-benchmarks.py (101 additions, 21 deletions)
@@ -3,24 +3,63 @@
 # To run:
 # - pip install ".[benchmarks]"
 # - Set up your AWS credentials locally. You can set them in `~/.aws/credentials` to be picked up automatically.
-# - Fill in the following details. You can create these in the EC2 console.
-# 1. key_name: Your EC2 key pair name.
-# 2. key_path: The to your local private key file.
+# - Add environment variables which hold the values below. You can create these in the EC2 console.
+# 1. TILEDB_EC2_KEY_NAME: Your EC2 key pair name.
+# 2. TILEDB_EC2_KEY_PATH: The path to your local private key file.
 # - Make sure to `chmod 400 /path/to/key.pem` after download.
-# - python src/benchmarks/ann-benchmarks.py
+# - caffeinate python src/benchmarks/ann-benchmarks.py

 import logging
 import os
 import socket
 import time
+from datetime import datetime

 import boto3
 import paramiko

-# You must fill these in before running the script:
-key_name = "key_name"
-key_path = "/path/to/key.pem"
+installations = ["tiledb"]
+algorithms = [
+    "tiledb-ivf-flat",
+    "tiledb-ivf-pq",
+    "tiledb-flat",
+    # NOTE(paris): Commented out until Vamana disk space usage is optimized.
+    # "tiledb-vamana"
+]

+also_benchmark_others = True
+if also_benchmark_others:
+    # TODO(paris): Some of these are failing so commented out. Investigate and re-enable.
+    installations += [
+        # "flann",
+        # "faiss",
+        # "hnswlib",
+        # "weaviate"
+        # "milvus",
+        "pgvector"
+    ]
+    algorithms += [
+        # "flann",
+        # "faiss-ivf",
+        # "faiss-lsh",
+        # "faiss-ivfpqfs",
+        # "hnswlib",
+        # "weaviate",
+        # "milvus-flat",
+        # "milvus-ivfflat",
+        # "milvus-ivfpq",
+        # "milvus-scann",
+        # "milvus-hnsw",
+        "pgvector",
+    ]
+
+# You must set these before running the script:
+key_name = os.environ.get("TILEDB_EC2_KEY_NAME")
+key_path = os.environ.get("TILEDB_EC2_KEY_PATH")
+if key_name is None:
+    raise ValueError("Please set TILEDB_EC2_KEY_NAME before running.")
+if key_path is None:
+    raise ValueError("Please set TILEDB_EC2_KEY_PATH before running.")
 if not os.path.exists(key_path):
     raise FileNotFoundError(
         f"Key file not found at {key_path}. Please set the correct path before running."
@@ -34,14 +73,20 @@
 ami_id = "ami-09e647bf7a368e505"
 username = "ec2-user"

-# Configure logging
+# Configure logging.
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()
-results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
+
+# Create a new folder in results_dir with the current date and time.
+results_dir = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)),
+    "results",
+    datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
+)
 os.makedirs(results_dir, exist_ok=True)
+
+# Also log to a text file.
 log_file_path = os.path.join(results_dir, "ann-benchmarks-logs.txt")
-if os.path.exists(log_file_path):
-    open(log_file_path, "w").close()
 file_handler = logging.FileHandler(log_file_path)
 file_handler.setLevel(logging.INFO)
 logger.addHandler(file_handler)
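Results now land in a per-run folder named after the current date and time, so the old truncation of a shared log file is no longer needed. A self-contained sketch of the same console-plus-file logging setup, assuming a results directory next to the script:

```python
import logging
import os
from datetime import datetime

# One timestamped folder per run, e.g. results/2024-01-31_12-00-00/.
results_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "results",
    datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
)
os.makedirs(results_dir, exist_ok=True)

# Console logging via basicConfig, plus a file handler inside the run folder.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.addHandler(logging.FileHandler(os.path.join(results_dir, "logs.txt")))
logger.info("Logging this run to %s", results_dir)
```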
@@ -102,6 +147,17 @@ def execute_commands(ssh, commands):
         SecurityGroupIds=security_group_ids,
         MinCount=1,
         MaxCount=1,
+        BlockDeviceMappings=[
+            {
+                "DeviceName": "/dev/xvda",
+                "Ebs": {
+                    # Size in GiB.
+                    "VolumeSize": 30,
+                    # General Purpose SSD (gp3).
+                    "VolumeType": "gp3",
+                },
+            }
+        ],
     )
     instance_id = response["Instances"][0]["InstanceId"]
     logger.info(f"Launched EC2 instance with ID: {instance_id}")
@@ -150,6 +206,7 @@ def execute_commands(ssh, commands):
         "sudo yum install docker -y",
         "sudo service docker start",
         "sudo usermod -a -G docker ec2-user",
+        # Docker will not yet be in the groups:
         "groups",
     ]
     execute_commands(ssh, initial_commands)
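These hunks repeatedly call execute_commands(ssh, commands), whose body sits outside this diff. A plausible sketch of such a helper built on paramiko's exec_command, assuming a module-level logger like the script's; the real implementation may differ:

```python
import logging

import paramiko

logger = logging.getLogger(__name__)

def execute_commands(ssh: paramiko.SSHClient, commands: list[str]) -> None:
    # Hypothetical sketch of the helper referenced in the diff.
    for command in commands:
        logger.info(f"Running: {command}")
        _stdin, stdout, stderr = ssh.exec_command(command)
        # Block until the remote command finishes, then log its output.
        exit_status = stdout.channel.recv_exit_status()
        logger.info(stdout.read().decode())
        if exit_status != 0:
            logger.error(stderr.read().decode())
```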
@@ -159,25 +216,44 @@ def execute_commands(ssh, commands):
     ssh.close()
     time.sleep(10)
     ssh.connect(public_dns, username=username, key_filename=key_path)
-    logger.info("Reconnected to the instance.")
+    logger.info("Reconnected to the instance to refresh group membership.")
+
+    ann_benchmarks_dir = "/home/ec2-user/ann-benchmarks"

     # Run the benchmarks.
     post_reconnect_commands = [
+        # Docker should now be in the groups:
         "groups",
         "git clone https://github.com/TileDB-Inc/ann-benchmarks.git",
-        "cd ann-benchmarks && pip3 install -r requirements.txt",
-        "cd ann-benchmarks && python3 install.py --algorithm tiledb",
-        "cd ann-benchmarks && python3 run.py --dataset sift-128-euclidean --algorithm tiledb-ivf-flat --force --batch",
-        "cd ann-benchmarks && sudo chmod -R 777 results/sift-128-euclidean/10/tiledb-ivf-flat-batch",
-        "cd ann-benchmarks && python3 create_website.py",
+        f"cd {ann_benchmarks_dir} && pip3 install -r requirements.txt",
     ]
+    for installation in installations:
+        post_reconnect_commands.append(
+            f"cd {ann_benchmarks_dir} && python3 install.py --algorithm {installation}"
+        )
+    for algorithm in algorithms:
+        post_reconnect_commands += [
+            f"cd {ann_benchmarks_dir} && python3 run.py --dataset sift-128-euclidean --algorithm {algorithm} --force --batch",
+            f"cd {ann_benchmarks_dir} && sudo chmod -R 777 results/sift-128-euclidean/10/{algorithm}-batch",
+        ]
+    post_reconnect_commands.append(
+        f"cd {ann_benchmarks_dir} && python3 create_website.py"
+    )
     execute_commands(ssh, post_reconnect_commands)

+    logger.info("Finished running the benchmarks.")
+
     # Download the results.
     remote_paths = [
-        "/home/ec2-user/ann-benchmarks/sift-128-euclidean_10_euclidean-batch.png",
-        "/home/ec2-user/ann-benchmarks/sift-128-euclidean_10_euclidean-batch.html",
+        f"{ann_benchmarks_dir}/index.html",
+        f"{ann_benchmarks_dir}/sift-128-euclidean_10_euclidean-batch.png",
+        f"{ann_benchmarks_dir}/sift-128-euclidean_10_euclidean-batch.html",
     ]
+    for algorithm in algorithms:
+        remote_paths += [
+            f"{ann_benchmarks_dir}/{algorithm}-batch.png",
+            f"{ann_benchmarks_dir}/{algorithm}-batch.html",
+        ]
     sftp = ssh.open_sftp()
     for remote_path in remote_paths:
         local_filename = os.path.basename(remote_path)
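The download loop is cut off just after local_filename is computed. A hedged sketch of how the rest of the loop could copy each plot and HTML page into the timestamped results folder with paramiko's SFTP client, reusing the script's ssh, remote_paths, results_dir, and logger; the actual error handling in the script may differ:

```python
sftp = ssh.open_sftp()
for remote_path in remote_paths:
    local_filename = os.path.basename(remote_path)
    local_path = os.path.join(results_dir, local_filename)
    try:
        # Copy each rendered benchmark artifact into the local results folder.
        sftp.get(remote_path, local_path)
        logger.info(f"Downloaded {remote_path} to {local_path}")
    except FileNotFoundError:
        # Some algorithms may not produce every artifact; skip missing files.
        logger.warning(f"Remote file not found, skipping: {remote_path}")
sftp.close()
```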
@@ -194,11 +270,15 @@ def execute_commands(ssh, commands):
 except Exception as e:
     logger.error(f"Error occurred: {e}")
     if "instance_id" in locals():
-        logger.info(f"Will terminate instance {instance_id}.")
+        logger.info(
+            f"Will terminate instance {instance_id} available at public_dns: {public_dns}."
+        )
         terminate_instance(instance_id)

 else:
-    logger.info(f"Finished, will try to terminate instance {instance_id}.")
+    logger.info(
+        f"Finished, will try to terminate instance {instance_id} available at public_dns: {public_dns}."
+    )
     if "instance_id" in locals():
         logger.info(f"Will terminate instance {instance_id}.")
         terminate_instance(instance_id)
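terminate_instance is called in both branches but defined outside the shown hunks. A hypothetical sketch of what such a helper looks like with boto3, assuming the script's module-level logger; the real definition may differ:

```python
import boto3

def terminate_instance(instance_id: str) -> None:
    # Hypothetical sketch of the cleanup helper referenced above.
    ec2 = boto3.client("ec2")
    ec2.terminate_instances(InstanceIds=[instance_id])
    # Optionally block until the instance is fully terminated.
    ec2.get_waiter("instance_terminated").wait(InstanceIds=[instance_id])
    logger.info(f"Terminated instance {instance_id}.")
```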
