Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 56 additions & 153 deletions oci/h100_health_checks/check_h100_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import re
import argparse
from datetime import datetime
from shared_logging import logger
from common_logger import CommonLogger
from common_logger import runWithDummyValues
from gpu_bw_test import BandwidthTest
from ecc_test import ECCTest
from gpu_remap_test import GPURemapTest
from rttcc_test import RTTCCTest
from rdma_link_flapping import LinkFlappingTest
from xid_checker import XidChecker
import platform
Expand All @@ -21,6 +25,8 @@ def get_metadata():
return requests.get(request_url, headers=headers).json()

def is_user_root():
if bool(runWithDummyValues):
return True
# Check if the user is root
if os.geteuid() != 0:
logger.debug("User is root")
Expand Down Expand Up @@ -75,126 +81,6 @@ def get_oca_version():
# Return the version
return version

def check_rttcc_status():
link_status = []
devices = ["mlx5_0", "mlx5_1", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_7", "mlx5_8", "mlx5_9", "mlx5_10", "mlx5_12", "mlx5_13", "mlx5_14", "mlx5_15", "mlx5_16", "mlx5_17"]
status = "disabled"
status_dict = {"devices": {}}
for device in devices:
if not is_user_root():
command = ['sudo', 'mlxreg', '-d', device, '-y', '--get', '--reg_name=PPCC', '--indexes=local_port=1,pnat=0,lp_msb=0,algo_slot=0,algo_param_index=0']
else:
command = ['mlxreg', '-d', device, '-y', '--set', 'cmd_type=3', '--reg_name=PPCC', '--indexes=local_port=1,pnat=0,lp_msb=0,algo_slot=0,algo_param_index=0']
result = subprocess.run(command, stdout=subprocess.PIPE)
output = result.stdout.decode('utf-8')
filtered_output = [line for line in output.split('\n') if line.startswith('value')]
for line in filtered_output:
logger.debug(line)
if "0x00000001" in line:
status_dict["devices"][device] = "enabled"

for device in status_dict["devices"]:
if status_dict["devices"][device] == "enabled":
logger.warning(f"RTTCC enabled on {device}")
status = "enabled"
link_status.append(f"RTTCC enabled on: {device}")
else:
logger.info(f"RTTCC status for {device}: disabled")
if status == "disabled":
logger.info(f"RTTCC disabled check: Passed")
else:
logger.error(f"RTTCC disabled check: Failed")

return link_status

def check_ecc_errors():
ecc_issues = []
try:
# Run the nvidia-smi -q command
result = subprocess.run(['nvidia-smi', '-q'], stdout=subprocess.PIPE)
except FileNotFoundError:
logger.warning("Skipping SRAM/DRAM ECC Test: nvidia-smi command not found")
return []

# Decode the output from bytes to string
output = result.stdout.decode('utf-8')

# Find the lines containing "SRAM Correctable" and "DRAM Correctable"
sram_matches = re.findall(r'SRAM Uncorrectable\s+:\s+(\d+)', output)
if len(sram_matches)==0:
sram_matches = re.findall(r'SRAM Uncorrectable Parity\s+:\s+(\d+)', output)
dram_matches = re.findall(r'DRAM Uncorrectable\s+:\s+(\d+)', output)
gpu_matches = re.findall(r'\nGPU\s+(.*)\n', output)
vol_sram_line = sram_matches[0::2]
vol_dram_line = dram_matches[0::2]
agg_sram_line = sram_matches[1::2]
agg_dram_line = dram_matches[1::2]

for i, gpu in enumerate(gpu_matches):
logger.debug(f"GPU: {gpu}")
if vol_sram_line[i] != "0":
logger.debug(f"Volatile SRAM Uncorrectable: {vol_sram_line[i]}")
ecc_issues.append(f"{gpu_matches[i]} - Volatile SRAM Uncorrectable: {vol_sram_line[i]}")
if vol_dram_line[i] != "0":
logger.debug(f"Volatile DRAM Uncorrectable: {vol_dram_line[i]}")
ecc_issues.append(f"{gpu_matches[i]} - Volatile DRAM Uncorrectable: {vol_dram_line[i]}")
if agg_sram_line[i] != "0":
logger.debug(f"Aggregate SRAM Uncorrectable: {agg_sram_line[i]}")
ecc_issues.append(f"{gpu_matches[i]} - Aggregate SRAM Uncorrectable: {agg_sram_line[i]}")
if agg_dram_line[i] != "0":
logger.debug(f"Aggregate DRAM Uncorrectable: {agg_dram_line[i]}")
ecc_issues.append(f"{gpu_matches[i]} - Aggregate DRAM Uncorrectable: {agg_dram_line[i]}")


# Check if there are ecc_issues
if len(ecc_issues) == 0:
logger.info("GPU ECC Test: Passed")
else:
logger.warning("GPU ECC Test: Failed")

return ecc_issues

def check_row_remap_errors():
remap_issues = []
try:
# Run the nvidia-smi -q command
result = subprocess.run(['nvidia-smi', '--query-remapped-rows=remapped_rows.pending,remapped_rows.failure,remapped_rows.uncorrectable', '--format=csv,noheader'], stdout=subprocess.PIPE)

if result.returncode != 0:
logger.debug(f"Check row remap command exited with error code: {result.returncode}")

except FileNotFoundError:
logger.warning("Skipping Row Remap Test: nvidia-smi command not found")
return []

# Decode the output from bytes to string
output = result.stdout.decode('utf-8')
logger.debug("Output: {}".format(output))
for i, line in enumerate(output.split('\n')):
if line == "":
continue
tmp_data = line.split(",")
tmp_data = [x.strip() for x in tmp_data]
if tmp_data[0] != "0":
logger.debug(f"GPU: {i} - Row Remap Pending: {tmp_data[0]}")
remap_issues.append(f"GPU: {i} Row Remap Pending: {tmp_data[0]}")
if tmp_data[1] != "0":
logger.debug(f"GPU: {i} - Row Remap Failure: {tmp_data[1]}")
#remap_issues.append(f"GPU: {i} Row Remap Failure: {tmp_data[1]}")
if tmp_data[2] != "0":
logger.debug(f"GPU: {i} - Row Remap Uncorrectable: {tmp_data[2]}")
if int(tmp_data[2]) > 512:
remap_issues.append(f"GPU: {i} - Row Remap Uncorrectable >512: {tmp_data[2]}")
else:
remap_issues.append(f"GPU: {i} - Row Remap Uncorrectable <512: {tmp_data[2]}")# Check if there are ecc_issues

if len(remap_issues) == 0:
logger.info("GPU Remap Test: Passed")
else:
logger.warning("GPU Remap Test: Failed")

return remap_issues

def check_rdma_link_status():
status = True
metadata=get_metadata()
Expand Down Expand Up @@ -260,6 +146,8 @@ def check_rdma_link_status():
return link_issues

def get_host_serial():
if (runWithDummyValues):
return "2349XLG02D"
# Run the shell command
if not is_user_root():
result = subprocess.run(['sudo', 'dmidecode', '-s', 'system-serial-number'], stdout=subprocess.PIPE)
Expand Down Expand Up @@ -363,31 +251,48 @@ def slurm_reason(message):
parser.add_argument('-slurm','--slurm', dest='slurm', action='store_true', default=False, help='Add a Slurm message')
args = parser.parse_args()

logger = CommonLogger.getLogger("h100", None, None)
logger.setLevel(args.log_level)

# Summarize the results
try:
host_serial = get_host_serial()
logger.setHostSerial(host_serial)
except Exception as e:
logger.warning(f"Failed to get host serial number with error: {e}")
host_serial = "Unknown"

logger.info(f"--------- Starting Host setup check for {host_serial} ---------")

datetime_str = datetime.now().strftime('%Y-%m-%d-%H%M%S')
logger.info(f"Started GPU host setup check at: {datetime_str}")
try:
oca_version = get_oca_version()
except Exception as e:
logger.warning(f"Failed to get Oracle Cloud Agent version with error: {e}")
oca_version = "Unknown"

rttc = None
try:
rttcc_issues = check_rttcc_status()
rttc = RTTCCTest(is_user_root())
rttcc_issues = rttc.check_rttcc_status()
except Exception as e:
logger.warning(f"Failed to check RTTCC status with error: {e}")
rttcc_issues = []

# Check for ECC errors
ecc = None
try:
ecc_issues = check_ecc_errors()
ecc = ECCTest()
ecc_issues = ecc.check_ecc_errors()
except Exception as e:
logger.warning(f"Failed to check ECC errors with error: {e}")
ecc_issues = []

# Check for row remap errors
gpurremap = GPURemapTest()
try:
remap_results = check_row_remap_errors()
remap_results = gpurremap.check_row_remap_errors()
except Exception as e:
logger.warning(f"Failed to check row remap errors with error: {e}")
remap_results = []
Expand All @@ -409,6 +314,7 @@ def slurm_reason(message):
lft_issues = {"failures": [], "link_down": []}

# Check for GPU Xid errors
xc = None
try:
xc = XidChecker()
xid_results = xc.check_gpu_xid()
Expand All @@ -417,9 +323,10 @@ def slurm_reason(message):
xid_results = {"status": "None", "results": {}}

# Check GPU bandwidth
bwt = None
bwt_results = None
try:
if args.bw_test == True or args.run_all == True:
if bool(runWithDummyValues) or args.bw_test == True or args.run_all == True:
if args.bw_test_exe:
bwt = BandwidthTest(bw_test_exe=args.bw_test_exe)
else:
Expand Down Expand Up @@ -454,64 +361,60 @@ def slurm_reason(message):
slurm_drain_reason = ""
slurm_error_count = 0

logger.set("h100", host_serial, None)
logger.info(f"--------- Summary of Host setup check for {host_serial} ---------")
if oca_version < "1.39.0":

if oca_version is None or oca_version == "Unknown" or oca_version < "1.39.0":
logger.error(f"Oracle Cloud Agent: {oca_version} needs to be updated to 1.39.0 or higher")
slurm_reason("OCA version Error")

if len(rttcc_issues) > 0:
logger.error(f"RTTCC issues: {rttcc_issues}")
rttc.logResults(rttcc_issues)
slurm_reason("RTTCC Error")

if len(ecc_issues) > 0:
ecc_error=False
for issue in ecc_issues:
if "Skipped" in issue:
logger.warning(f"{host_serial} - {issue}")
else:
if "Aggregate" in issue:
logger.warning(f"{host_serial} - ECC issues: {issue}")
else:
logger.error(f"{host_serial} - ECC issues: {issue}")
ecc_error=True
ecc_error=ecc.logResults(ecc_issues, host_serial)
if ecc_error:
slurm_reason("ECC Error")

if len(remap_results) > 0:
remap_error=False
for issue in remap_results:
if "<512" in issue:
logger.warning(f"{host_serial} - {issue}")
else:
logger.error(f"{host_serial} - {issue}")
remap_error=True
remap_error=gpurremap.logResults(remap_results)
if remap_error:
slurm_reason("Remap Error")

if xid_results["status"] == "Failed":
xc.logResults(host_serial, xid_results)
for xid in xid_results["results"]:
for pci in xid_results["results"][xid]["results"]:
logger.error(f"{host_serial} - GPU Xid {xid} device: {pci}, {xid_results['results'][xid]['description']}")
slurm_reason("XID Error")

if len(rdma_link_issues) > 0:
for issue in rdma_link_issues:
logger.error(f"{host_serial} - RDMA link issues: {issue}")
logger.error(f"{issue}")
slurm_reason("RDMA Link Error")

if len(lft_issues["failures"]) > 0 or len(lft_issues["link_down"]) > 0:
lft.logResults(lft_issues, host_serial);
if len(lft_issues["failures"]) > 0:
for issue in lft_issues["failures"]:
logger.error(f"{host_serial} - RDMA link flapping issues: {issue}")
slurm_reason("RDMA Link Flapping Error")
if len(lft_issues["link_down"]) > 0:
for issue in lft_issues["link_down"]:
logger.error(f"{host_serial} - RDMA link down issues: {issue}")
slurm_reason("RDMA Link Down Error")

if bwt_results != None:
bwt.logResults(host_serial, bwt_results)
if bwt_results["status"] == "Failed":
for issue in bwt_results["issues"]:
logger.error(f"{host_serial} - GPU bandwidth issues: {issue}")
slurm_reason("GPU Bwt Error")
for device in bwt_results["devices"]:
for issue in bwt_results["devices"][device]:
slurm_reason("GPU Bwt Error")

if bus_results:
logger.error(f"{host_serial} - Bus issues: {bus_results}")
logger.error2("Bus issues", f"{bus_results}")
slurm_reason("GPU Bus Error")

if gpu_results:
logger.error(f"{host_serial} - Missing GPU(s): {gpu_results}")
logger.error("Missing GPU(s)", f"{gpu_results}")
slurm_reason("Missing GPU Error")

datetime_str = datetime.now().strftime('%Y-%m-%d-%H%M%S')
Expand Down
Loading