diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py
index 71e7568a9..5e53eaf5f 100644
--- a/bitsandbytes/cuda_specs.py
+++ b/bitsandbytes/cuda_specs.py
@@ -6,6 +6,11 @@
 from typing import Optional
+import sys
 
 import torch
+
+# `hipinfo` is the HIP SDK's Windows equivalent of `rocminfo`, which is not
+# shipped on Windows.
+ROCMINFO_CMD = "hipinfo" if sys.platform == "win32" else "rocminfo"
 
 
 @dataclasses.dataclass(frozen=True)
@@ -83,7 +88,7 @@ def get_rocm_gpu_arch() -> str:
     logger = logging.getLogger(__name__)
     try:
         if torch.version.hip:
-            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
+            result = subprocess.run([ROCMINFO_CMD], capture_output=True, text=True)
             match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
             if match:
                 return "gfx" + match.group(1)
@@ -102,15 +107,18 @@ def get_rocm_gpu_arch() -> str:
     return "unknown"
 
 
+# The wavefront (warp) size is the number of threads a GPU executes in
+# lockstep: 32 on NVIDIA, 64 on AMD GCN/CDNA, and 32 or 64 on RDNA.
 def get_rocm_warpsize() -> int:
     """Get ROCm warp size."""
     logger = logging.getLogger(__name__)
     try:
         if torch.version.hip:
-            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
-            match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
+            result = subprocess.run([ROCMINFO_CMD], capture_output=True, text=True)
+            # rocminfo prints "Wavefront Size: 64(0x40)"; hipinfo prints "warpSize: 32".
+            match = re.search(r"(?:wavefront\s+|warp)size:\s+([0-9]+)", result.stdout, re.IGNORECASE)
             if match:
                 return int(match.group(1))
             else:
                 # default to 64 to be safe
                 return 64