Skip to content

Commit 22dc2c9

Browse files
authored
Add script to create an EC2 instance and run ann-benchmarks (#412)
1 parent 0c024db commit 22dc2c9

File tree

3 files changed

+208
-0
lines changed

3 files changed

+208
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
[project.optional-dependencies]
2828
test = ["nbmake", "pytest<8.0.0", "pytest-xdist"]
2929
formatting = ["pre-commit"]
30+
benchmarks = ["boto3", "paramiko"]
3031

3132
[project.urls]
3233
homepage = "https://tiledb.com"

src/benchmarks/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
results/

src/benchmarks/ann-benchmarks.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Used to run ann-benchmarks on an EC2 instance and download the results.
2+
#
3+
# To run:
4+
# - pip install ".[benchmarks]"
5+
# - Set up your AWS credentials locally. You can set them in `~/.aws/credentials` to be picked up automatically.
6+
# - Fill in the following details. You can create these in the EC2 console.
7+
# 1. key_name: Your EC2 key pair name.
8+
# 2. key_path: The to your local private key file.
9+
# - Make sure to `chmod 400 /path/to/key.pem` after download.
10+
# - python src/benchmarks/ann-benchmarks.py
11+
12+
import logging
13+
import os
14+
import socket
15+
import time
16+
17+
import boto3
18+
import paramiko
19+
20+
# You must fill these in before running the script:
21+
key_name = "key_name"
22+
key_path = "/path/to/key.pem"
23+
24+
if not os.path.exists(key_path):
25+
raise FileNotFoundError(
26+
f"Key file not found at {key_path}. Please set the correct path before running."
27+
)
28+
29+
# You do not need to change these.
30+
security_group_ids = ["sg-04258b401ce76d246"]
31+
# 64 vCPU, 512 GiB, EBS-Only.
32+
instance_type = "r6i.16xlarge"
33+
# Amazon Linux 2023 AMI 2023.4.20240528.0 x86_64 HVM kernel-6.1 - 64 bit (x86) - uefi-preferred.
34+
ami_id = "ami-09e647bf7a368e505"
35+
username = "ec2-user"
36+
37+
# Configure logging
38+
logging.basicConfig(level=logging.INFO)
39+
logger = logging.getLogger()
40+
results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
41+
os.makedirs(results_dir, exist_ok=True)
42+
log_file_path = os.path.join(results_dir, "ann-benchmarks-logs.txt")
43+
if os.path.exists(log_file_path):
44+
open(log_file_path, "w").close()
45+
file_handler = logging.FileHandler(log_file_path)
46+
file_handler.setLevel(logging.INFO)
47+
logger.addHandler(file_handler)
48+
49+
# Create an EC2 client
50+
ec2 = boto3.client("ec2")
51+
52+
53+
def terminate_instance(instance_id):
54+
logger.info(f"Terminating instance {instance_id}...")
55+
ec2.terminate_instances(InstanceIds=[instance_id])
56+
logger.info(f"Instance {instance_id} terminated.")
57+
58+
59+
def check_ssh_ready(public_dns, key_filename):
60+
"""Poll until SSH is ready"""
61+
timeout = 60 * 2
62+
logger.info(f"Will poll for {timeout} seconds until SSH is ready.")
63+
ssh = paramiko.SSHClient()
64+
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
65+
end_time = time.time() + timeout
66+
while time.time() < end_time:
67+
try:
68+
ssh.connect(
69+
public_dns, username=username, key_filename=key_filename, timeout=20
70+
)
71+
ssh.close()
72+
logger.info("SSH is ready.")
73+
return True
74+
except Exception as e:
75+
logger.error(f"Waiting for SSH: {e}")
76+
time.sleep(15)
77+
return False
78+
79+
80+
def execute_commands(ssh, commands):
81+
"""Execute a list of commands on the SSH connection and stream the output"""
82+
for command in commands:
83+
logger.info(f"Executing command: {command}")
84+
stdin, stdout, stderr = ssh.exec_command(command)
85+
86+
# Stream stdout
87+
for line in iter(stdout.readline, ""):
88+
logger.info(line.strip())
89+
90+
# Stream stderr
91+
for line in iter(stderr.readline, ""):
92+
logger.error(line.strip())
93+
94+
95+
try:
96+
# Launch an EC2 instance
97+
logger.info("Launching EC2 instance...")
98+
response = ec2.run_instances(
99+
ImageId=ami_id,
100+
InstanceType=instance_type,
101+
KeyName=key_name,
102+
SecurityGroupIds=security_group_ids,
103+
MinCount=1,
104+
MaxCount=1,
105+
)
106+
instance_id = response["Instances"][0]["InstanceId"]
107+
logger.info(f"Launched EC2 instance with ID: {instance_id}")
108+
109+
# Wait for the instance to be in a running state.
110+
logger.info("Waiting for instance to enter running state...")
111+
waiter = ec2.get_waiter("instance_running")
112+
waiter.wait(InstanceIds=[instance_id])
113+
114+
# Get the public DNS name of the instance.
115+
instance_description = ec2.describe_instances(InstanceIds=[instance_id])
116+
public_dns = instance_description["Reservations"][0]["Instances"][0][
117+
"PublicDnsName"
118+
]
119+
logger.info(f"Public DNS of the instance: {public_dns}")
120+
121+
# Tag the instance.
122+
instance_name = f"vector-search-ann-benchmarks-{socket.gethostname()}"
123+
logger.info(f"Will name the instance: {instance_name}")
124+
ec2.create_tags(
125+
Resources=[instance_id],
126+
Tags=[
127+
{
128+
"Key": "Name",
129+
"Value": instance_name,
130+
},
131+
],
132+
)
133+
134+
# Wait for SSH to be ready
135+
if not check_ssh_ready(public_dns=public_dns, key_filename=key_path):
136+
raise RuntimeError("SSH did not become ready in time")
137+
138+
# Connect to the instance using paramiko
139+
logger.info("Connecting to the instance via SSH...")
140+
ssh = paramiko.SSHClient()
141+
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
142+
ssh.connect(public_dns, username=username, key_filename=key_path)
143+
logger.info("Connected to the instance.")
144+
145+
# Initial setup commands
146+
initial_commands = [
147+
"sudo yum update -y",
148+
"sudo yum install git -y",
149+
"sudo yum install python3.9-pip -y",
150+
"sudo yum install docker -y",
151+
"sudo service docker start",
152+
"sudo usermod -a -G docker ec2-user",
153+
"groups",
154+
]
155+
execute_commands(ssh, initial_commands)
156+
157+
# Reconnect to the instance to refresh group membership.
158+
logger.info("Reconnecting to the instance to refresh group membership...")
159+
ssh.close()
160+
time.sleep(10)
161+
ssh.connect(public_dns, username=username, key_filename=key_path)
162+
logger.info("Reconnected to the instance.")
163+
164+
# Run the benchmarks.
165+
post_reconnect_commands = [
166+
"groups",
167+
"git clone https://github.com/TileDB-Inc/ann-benchmarks.git",
168+
"cd ann-benchmarks && pip3 install -r requirements.txt",
169+
"cd ann-benchmarks && python3 install.py --algorithm tiledb",
170+
"cd ann-benchmarks && python3 run.py --dataset sift-128-euclidean --algorithm tiledb-ivf-flat --force --batch",
171+
"cd ann-benchmarks && sudo chmod -R 777 results/sift-128-euclidean/10/tiledb-ivf-flat-batch",
172+
"cd ann-benchmarks && python3 create_website.py",
173+
]
174+
execute_commands(ssh, post_reconnect_commands)
175+
176+
# Download the results.
177+
remote_paths = [
178+
"/home/ec2-user/ann-benchmarks/sift-128-euclidean_10_euclidean-batch.png",
179+
"/home/ec2-user/ann-benchmarks/sift-128-euclidean_10_euclidean-batch.html",
180+
]
181+
sftp = ssh.open_sftp()
182+
for remote_path in remote_paths:
183+
local_filename = os.path.basename(remote_path)
184+
local_path = os.path.join(results_dir, local_filename)
185+
logger.info(f"Downloading {remote_path} to {local_path}...")
186+
sftp.get(remote_path, local_path)
187+
logger.info(f"File downloaded to {local_path}.")
188+
logger.info("File downloading complete, closing the SFTP connection.")
189+
sftp.close()
190+
191+
logger.info("Benchmarking complete, closing the SSH connection.")
192+
ssh.close()
193+
194+
except Exception as e:
195+
logger.error(f"Error occurred: {e}")
196+
if "instance_id" in locals():
197+
logger.info(f"Will terminate instance {instance_id}.")
198+
terminate_instance(instance_id)
199+
200+
else:
201+
logger.info(f"Finished, will try to terminate instance {instance_id}.")
202+
if "instance_id" in locals():
203+
logger.info(f"Will terminate instance {instance_id}.")
204+
terminate_instance(instance_id)
205+
206+
logger.info("Done.")

0 commit comments

Comments
 (0)