@@ -610,6 +610,29 @@ def get_gpu_info():
610610 except Exception as err :
611611 _log .debug ("Exception was raised when running nvidia-smi: %s" , err )
612612 _log .info ("No NVIDIA GPUs detected" )
613+
614+ try :
615+ cmd = "rocm-smi --showdriverversion --csv"
616+ _log .debug ("Trying to determine AMD GPU driver on Linux via cmd '%s'" , cmd )
617+ out , ec = run_cmd (cmd , force_in_dry_run = True , trace = False , stream_output = False )
618+ if ec == 0 :
619+ amd_driver = out .strip ().split ('\n ' )[1 ].split (',' )[1 ]
620+
621+ cmd = "rocm-smi --showproductname --csv"
622+ _log .debug ("Trying to determine AMD GPU info on Linux via cmd '%s'" , cmd )
623+ out , ec = run_cmd (cmd , force_in_dry_run = True , trace = False , stream_output = False )
624+ if ec == 0 :
625+ for line in out .strip ().split ('\n ' )[1 :]:
626+ amd_card_model = line .split (',' )[2 ]
627+ amd_gpu = ', ' .join ([amd_card_model , amd_driver ])
628+ amd_gpu_info = gpu_info .setdefault ('AMD' , {})
629+ amd_gpu_info .setdefault (amd_gpu , 0 )
630+ amd_gpu_info [amd_gpu ] += 1
631+ else :
632+ _log .debug ("None zero exit (%s) from rocm-smi: %s" , ec , out )
633+ except Exception as err :
634+ _log .debug ("Exception was raised when running rocm-smi: %s" , err )
635+ _log .info ("No AMD GPUs detected" )
613636 else :
614637 _log .info ("Only know how to get GPU info on Linux, assuming no GPUs are present" )
615638
0 commit comments