Skip to content

Commit 1d140f8

Browse files
authored
Add fallback to BF16 support check (#7754)
When DeepSpeed is installed without `--no-build-isolation`, `torch_info` contains placeholder values (`'0.0'`), causing BF16 tests to be skipped incorrectly. This PR adds runtime fallback detection for the torch, NCCL, and CUDA versions whenever the stored values are invalid.

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
1 parent bfb66c6 commit 1d140f8

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

tests/unit/util.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,41 @@ def skip_on_cuda(valid_cuda):
3838

3939
def bf16_required_version_check(accelerator_check=True):
4040
split_version = lambda x: map(int, x.split('.')[:2])
41-
TORCH_MAJOR, TORCH_MINOR = split_version(torch_info['version'])
42-
NCCL_MAJOR, NCCL_MINOR = split_version(torch_info['nccl_version'])
43-
CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version'])
41+
42+
# torch_info may have stale/zero values if installed without --no-build-isolation
43+
# In that case, fall back to runtime detection
44+
if torch_info['version'] == '0.0':
45+
# Use runtime torch version
46+
TORCH_MAJOR, TORCH_MINOR = split_version(torch.__version__)
47+
else:
48+
TORCH_MAJOR, TORCH_MINOR = split_version(torch_info['version'])
49+
50+
if torch_info['nccl_version'] == '0.0':
51+
# Use runtime NCCL version if available
52+
if torch.cuda.is_available(): #ignore-cuda
53+
try:
54+
nccl_ver = torch.cuda.nccl.version() #ignore-cuda
55+
NCCL_MAJOR, NCCL_MINOR = nccl_ver[0], nccl_ver[1]
56+
except (AttributeError, RuntimeError):
57+
NCCL_MAJOR, NCCL_MINOR = 0, 0
58+
else:
59+
# No CUDA means no NCCL; NPU/HPU/XPU have separate checks below
60+
NCCL_MAJOR, NCCL_MINOR = 0, 0
61+
else:
62+
NCCL_MAJOR, NCCL_MINOR = split_version(torch_info['nccl_version'])
63+
64+
if torch_info['cuda_version'] == '0.0':
65+
# Use runtime CUDA version
66+
if torch.cuda.is_available(): #ignore-cuda
67+
cuda_ver = torch.version.cuda
68+
if cuda_ver:
69+
CUDA_MAJOR, CUDA_MINOR = split_version(cuda_ver)
70+
else:
71+
CUDA_MAJOR, CUDA_MINOR = 0, 0
72+
else:
73+
CUDA_MAJOR, CUDA_MINOR = 0, 0
74+
else:
75+
CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version'])
4476

4577
# Sometimes bf16 tests are runnable even if not natively supported by accelerator
4678
if accelerator_check:

0 commit comments

Comments
 (0)