
Commit bede39f

Make TMA tests compatible with older CUDA toolchains (#5221)
TMA fences require CUDA toolchain 12.3 or greater, but the current gating does not check the CUDA toolchain version. This causes `test_experimental_tma.py` to fail when run with older CUDA toolchains.

## Before

With cuda-12.0:

```
55 failed, 9 passed in 18.11s
```

With cuda-12.4:

```
64 passed in 11.99s
```

## After

With cuda-12.0:

```
9 passed, 55 skipped in 4.26s
```

With cuda-12.4:

```
64 passed in 11.96s
```
1 parent e558838 commit bede39f
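For context, the fix gates on the CUDA toolchain version, not just the GPU architecture. A minimal sketch of that kind of gating, assuming `ptxas` is on `PATH` and using a hypothetical standalone parser rather than Triton's internal `_path_to_binary` helper:

```python
import re
import subprocess

def ptxas_version() -> tuple[int, int]:
    # `ptxas --version` prints e.g. "Cuda compilation tools, release 12.4, V12.4.131".
    out = subprocess.check_output(["ptxas", "--version"]).decode()
    major, minor = re.search(r"release (\d+)\.(\d+)", out).groups()
    return (int(major), int(minor))

def toolchain_supports_tma(byval_only: bool = False) -> bool:
    # __grid_constant__ (byval) TMA descriptors need CUDA 12.0;
    # TMA fences need CUDA 12.3.
    return ptxas_version() >= ((12, 0) if byval_only else (12, 3))
```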

File tree

2 files changed: +22 −6 lines


python/test/unit/hopper/test_experimental_tma.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -4,7 +4,7 @@
 import triton
 import triton.language as tl
 from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor)
-from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma
+from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma, supports_tma, tma_skip_msg

 from typing import Optional

@@ -29,9 +29,11 @@ def unwrap_tensor(t: torch.Tensor | triton.runtime.jit.TensorWrapper):
 tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"})


-@requires_tma
 @pytest.mark.parametrize("byval_tma", [True, False])
 def test_experimetal_descriptor_load(byval_tma):
+    if not supports_tma(byval_tma):
+        pytest.skip(tma_skip_msg(byval_tma))
+
     device = "cuda"
     SIZE = 128

@@ -82,11 +84,13 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])


-@requires_tma
 @pytest.mark.parametrize("num_stages", [1, 4])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)])
 @pytest.mark.parametrize("byval_tma", [True, False])
 def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma):
+    if not supports_tma(byval_tma):
+        pytest.skip(tma_skip_msg(byval_tma))
+
     device = "cuda"
     M, N, K = 8192, 8192, 1024
     torch.manual_seed(42)
```
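Note the design choice in this file: the `@requires_tma` decorator is replaced by an in-body `pytest.skip`. A skip condition evaluated in a decorator cannot depend on the `byval_tma` parameter, while the required toolchain version does. A sketch of the pattern (the test name is hypothetical; `supports_tma` and `tma_skip_msg` are the helpers added in this commit):

```python
import pytest
from triton._internal_testing import supports_tma, tma_skip_msg

@pytest.mark.parametrize("byval_tma", [True, False])
def test_needs_tma(byval_tma):  # hypothetical test
    # A module-level skipif is evaluated once, before parametrization,
    # so a parameter-dependent requirement (byval-only TMA needs CUDA 12.0,
    # fenced TMA needs 12.3) must be checked inside the test body.
    if not supports_tma(byval_tma):
        pytest.skip(tma_skip_msg(byval_tma))
    ...  # actual test body would go here
```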

python/triton/_internal_testing.py

Lines changed: 15 additions & 3 deletions

```diff
@@ -4,6 +4,7 @@
 import torch
 import triton
 import triton.language as tl
+from triton.backends.nvidia.compiler import _path_to_binary
 import pytest

 from numpy.random import RandomState
@@ -140,8 +141,19 @@ def to_numpy(x):
     raise ValueError(f"Not a triton-compatible tensor: {x}")


-def supports_tma():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+def supports_tma(byval_only=False):
+    _, cuda_version = _path_to_binary("ptxas")
+    min_cuda_version = (12, 0) if byval_only else (12, 3)
+    cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
+    assert len(cuda_version_tuple) == 2, cuda_version_tuple
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version
+
+
+def tma_skip_msg(byval_only=False):
+    if byval_only:
+        return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)"
+    else:
+        return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)"


-requires_tma = pytest.mark.skipif(not supports_tma(), reason="Requires TMA support (NVIDIA Hopper or higher)")
+requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())
```
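The new `supports_tma` compares version *tuples* rather than strings, which keeps multi-digit minor versions ordered correctly. A quick illustration of why:

```python
# Python compares tuples lexicographically, element by element:
assert (12, 4) >= (12, 3)       # cuda-12.4 supports TMA fences
assert not (12, 0) >= (12, 3)   # cuda-12.0 does not...
assert (12, 0) >= (12, 0)       # ...but it does support byval-only TMA

# Naive string comparison would misorder multi-digit minors:
assert "12.10" < "12.3"         # lexicographic string order (wrong for versions)
assert (12, 10) > (12, 3)       # tuple order (correct)
```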
