Skip to content

Commit 2bd17c5

Browse files
authored
Update ann-benchmarks.py to connect to running instance (#500)
1 parent f1405d4 commit 2bd17c5

File tree

1 file changed

+64
-48
lines changed

1 file changed

+64
-48
lines changed

src/benchmarks/ann-benchmarks.py

Lines changed: 64 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@
6464
f"Key file not found at {key_path}. Please set the correct path before running."
6565
)
6666

67+
# If you want to connect to a running instance, set both of these. Note that we could fetch instance_id from
68+
# the public_dns (via ec2.describe_instances()), but I had IAM permission issues, so we set it manually.
69+
# Example: "ec2-9-38-120-343.eu-central-1.compute.amazonaws.com"
70+
connect_to_running_instance_public_dns = ""
71+
# Example: "i-0790c23df32093234"
72+
connect_to_running_instance_id = ""
73+
6774
# You do not need to change these.
6875
security_group_ids = ["sg-04258b401ce76d246"]
6976
# 64 vCPU, 512 GiB, EBS-Only.
@@ -137,54 +144,62 @@ def execute_commands(ssh, commands):
137144

138145

139146
try:
140-
# Launch an EC2 instance
141-
logger.info("Launching EC2 instance...")
142-
response = ec2.run_instances(
143-
ImageId=ami_id,
144-
InstanceType=instance_type,
145-
KeyName=key_name,
146-
SecurityGroupIds=security_group_ids,
147-
MinCount=1,
148-
MaxCount=1,
149-
BlockDeviceMappings=[
150-
{
151-
"DeviceName": "/dev/xvda",
152-
"Ebs": {
153-
# Size in GiB.
154-
"VolumeSize": 30,
155-
# General Purpose SSD (gp3).
156-
"VolumeType": "gp3",
147+
if connect_to_running_instance_public_dns and connect_to_running_instance_id:
148+
# Connect to an existing EC2 instance.
149+
public_dns = connect_to_running_instance_public_dns
150+
instance_id = connect_to_running_instance_id
151+
logger.info(
152+
f"Will connect to running instance at public_dns: {public_dns} and instance_id: {instance_id}"
153+
)
154+
else:
155+
# Launch an EC2 instance.
156+
logger.info("Launching EC2 instance...")
157+
response = ec2.run_instances(
158+
ImageId=ami_id,
159+
InstanceType=instance_type,
160+
KeyName=key_name,
161+
SecurityGroupIds=security_group_ids,
162+
MinCount=1,
163+
MaxCount=1,
164+
BlockDeviceMappings=[
165+
{
166+
"DeviceName": "/dev/xvda",
167+
"Ebs": {
168+
# Size in GiB.
169+
"VolumeSize": 30,
170+
# General Purpose SSD (gp3).
171+
"VolumeType": "gp3",
172+
},
173+
}
174+
],
175+
)
176+
instance_id = response["Instances"][0]["InstanceId"]
177+
logger.info(f"Launched EC2 instance with ID: {instance_id}")
178+
179+
# Wait for the instance to be in a running state.
180+
logger.info("Waiting for instance to enter running state...")
181+
waiter = ec2.get_waiter("instance_running")
182+
waiter.wait(InstanceIds=[instance_id])
183+
184+
# Get the public DNS name of the instance.
185+
instance_description = ec2.describe_instances(InstanceIds=[instance_id])
186+
public_dns = instance_description["Reservations"][0]["Instances"][0][
187+
"PublicDnsName"
188+
]
189+
logger.info(f"Public DNS of the instance: {public_dns}")
190+
191+
# Tag the instance.
192+
instance_name = f"vector-search-ann-benchmarks-{socket.gethostname()}"
193+
logger.info(f"Will name the instance: {instance_name}")
194+
ec2.create_tags(
195+
Resources=[instance_id],
196+
Tags=[
197+
{
198+
"Key": "Name",
199+
"Value": instance_name,
157200
},
158-
}
159-
],
160-
)
161-
instance_id = response["Instances"][0]["InstanceId"]
162-
logger.info(f"Launched EC2 instance with ID: {instance_id}")
163-
164-
# Wait for the instance to be in a running state.
165-
logger.info("Waiting for instance to enter running state...")
166-
waiter = ec2.get_waiter("instance_running")
167-
waiter.wait(InstanceIds=[instance_id])
168-
169-
# Get the public DNS name of the instance.
170-
instance_description = ec2.describe_instances(InstanceIds=[instance_id])
171-
public_dns = instance_description["Reservations"][0]["Instances"][0][
172-
"PublicDnsName"
173-
]
174-
logger.info(f"Public DNS of the instance: {public_dns}")
175-
176-
# Tag the instance.
177-
instance_name = f"vector-search-ann-benchmarks-{socket.gethostname()}"
178-
logger.info(f"Will name the instance: {instance_name}")
179-
ec2.create_tags(
180-
Resources=[instance_id],
181-
Tags=[
182-
{
183-
"Key": "Name",
184-
"Value": instance_name,
185-
},
186-
],
187-
)
201+
],
202+
)
188203

189204
# Wait for SSH to be ready
190205
if not check_ssh_ready(public_dns=public_dns, key_filename=key_path):
@@ -233,7 +248,8 @@ def execute_commands(ssh, commands):
233248
)
234249
for algorithm in algorithms:
235250
post_reconnect_commands += [
236-
f"cd {ann_benchmarks_dir} && python3 run.py --dataset sift-128-euclidean --algorithm {algorithm} --force --batch",
251+
# NOTE: If you want to force re-running a benchmark even if the results exist, add --force.
252+
f"cd {ann_benchmarks_dir} && python3 run.py --dataset sift-128-euclidean --algorithm {algorithm} --batch",
237253
f"cd {ann_benchmarks_dir} && sudo chmod -R 777 results/sift-128-euclidean/10/{algorithm}-batch",
238254
]
239255
post_reconnect_commands.append(

0 commit comments

Comments
 (0)