Skip to content

Commit c32e906

Browse files
Check the right device per shape
1 parent 42fc627 commit c32e906

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

playbooks/roles/healthchecks/files/check_gpu_setup.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,15 @@
1010
from xid_checker import XidChecker
1111
import platform
1212
import os
13-
import sys
13+
import requests
14+
15+
def get_metadata():
    """Fetch the instance metadata document from the metadata endpoint.

    Issues a GET to ``http://169.254.169.254/opc/v2/instance/`` with the
    ``Authorization: Bearer Oracle`` header and returns the decoded JSON
    body.

    Returns:
        The JSON-decoded response body.

    Raises:
        requests.exceptions.RequestException: If the endpoint cannot be
            reached, does not answer within the timeout, or replies with
            a non-2xx status.
        ValueError: If the response body is not valid JSON.
    """
    headers = {'Authorization': 'Bearer Oracle'}
    metadata_url = "http://169.254.169.254/opc/"
    metadata_ver = "2"
    request_url = metadata_url + "v" + metadata_ver + "/instance/"
    # requests never times out by default; bound the wait so an
    # unresponsive endpoint cannot hang the whole health check.
    response = requests.get(request_url, headers=headers, timeout=10)
    # Fail loudly on an HTTP error instead of handing an error payload
    # back to callers as if it were instance metadata.
    response.raise_for_status()
    return response.json()
1422

1523
def is_user_root():
1624
# Check if the user is root
@@ -189,8 +197,14 @@ def check_row_remap_errors():
189197

190198
def check_rdma_link_status():
191199
status = True
192-
devices = ["mlx5_0", "mlx5_1", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_7", "mlx5_8", "mlx5_9", "mlx5_10", "mlx5_12", "mlx5_13", "mlx5_14", "mlx5_15", "mlx5_16", "mlx5_17"]
193-
200+
metadata=get_metadata()
201+
shape=metadata['shape']
202+
if shape == "BM.GPU.H100.8":
203+
devices = ["mlx5_0", "mlx5_1", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_7", "mlx5_8", "mlx5_9", "mlx5_10", "mlx5_12", "mlx5_13", "mlx5_14", "mlx5_15", "mlx5_16", "mlx5_17"]
204+
elif shape == "BM.GPU.B4.8" or shape == "BM.GPU.A100-v2.8":
205+
devices = ["mlx5_1", "mlx5_2", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_7", "mlx5_8", "mlx5_9", "mlx5_10", "mlx5_11", "mlx5_12", "mlx5_14", "mlx5_15", "mlx5_16", "mlx5_17"]
206+
elif shape == "BM.GPU.4.8":
207+
devices = ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3", "mlx5_6", "mlx5_7", "mlx5_8", "mlx5_9", "mlx5_10", "mlx5_11", "mlx5_12", "mlx5_13", "mlx5_14", "mlx5_15", "mlx5_16", "mlx5_17"]
194208
link_issues = []
195209
for device in devices:
196210
# Run the mlxlink command
@@ -501,7 +515,7 @@ def slurm_reason(message):
501515
slurm_reason("Missing GPU Error")
502516

503517
datetime_str = datetime.now().strftime('%Y-%m-%d-%H%M%S')
504-
logger.info(f"Finished H100 setup check at: {datetime_str}")
518+
logger.info(f"Finished GPU host setup check at: {datetime_str}")
505519

506520
if slurm_error_count > 0 and args.slurm:
507521
print("Healthcheck:: "+slurm_drain_reason[:-1])

0 commit comments

Comments
 (0)