
Commit af0a9f2

[AMD] Enable pipeliner test for scaled_dot (#5068)
This commit enables the pipeliner test for scaled_dot on the AMD backend. Along the way, it unifies some target/arch probe utilities into the common `_internal_testing` file.
1 parent: 20361eb · commit: af0a9f2
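
As context for the consolidation described above, a test can now gate on the shared probe helpers instead of redefining them locally. A minimal sketch (hypothetical test name, not part of this commit), using only helpers that exist in `triton._internal_testing` after this change:

import pytest
from triton._internal_testing import is_cuda, is_hip_cdna

# Skip wherever scaled_dot is not implemented: only CUDA and HIP CDNA (MI200/MI300) qualify.
@pytest.mark.skipif(not (is_cuda() or is_hip_cdna()), reason="scaled_dot requires CUDA or HIP CDNA")
def test_scaled_dot_pipeline_smoke():
    ...  # kernel body elided; see test_pipeliner.py below for the real test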

5 files changed: +24, -51 lines changed

python/test/unit/language/test_compile_errors.py

Lines changed: 2 additions & 17 deletions
@@ -7,22 +7,7 @@
 import triton.language as tl
 from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 import traceback
-
-
-def is_interpreter():
-    return os.environ.get('TRITON_INTERPRET', '0') == '1'
-
-
-def is_cuda():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-
-def is_hip():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"
-
-
-def is_on_mi300():
-    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')
+from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300
 
 
 def test_err_undefined_variable():
@@ -367,7 +352,7 @@ def test_fp8_support(dtype):
         if cc >= (8, 9):
             supported_dtypes.append(tl.float8e4nv)
     elif is_hip():
-        if is_on_mi300():
+        if is_hip_mi300():
             supported_dtypes += [tl.float8e4b8, tl.float8e5b16]
     elif is_interpreter():
         supported_dtypes = [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]

python/test/unit/language/test_conversions.py

Lines changed: 3 additions & 13 deletions
@@ -1,24 +1,14 @@
 # fmt: off
 
 
-import os
 import numpy as np
 import torch
 import pytest
 import triton
 import triton.language as tl
 
-def is_interpreter():
-    return os.environ.get('TRITON_INTERPRET', '0') == '1'
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi300
 
-def is_cuda():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-def is_hip():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"
-
-def is_on_mi300():
-    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')
 
 def matching_int(dtype):
     if dtype.primitive_bitwidth == 8:
@@ -283,7 +273,7 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
 def test_typeconvert_upcast(src_dtype, dst_dtype, device):
     if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9))
             or (src_dtype in ('float8e4nv', 'float8e4b15') and is_hip())
-            or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()))):
+            or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_hip_mi300()))):
         # If the dtype should error out in the given device, we assert that and return
         with pytest.raises(triton.CompilationError, match="not supported in this architecture"):
             launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)
@@ -334,7 +324,7 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
     if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
 
-    if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()):
+    if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_hip_mi300()):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300")
 
     # dtype : (exponent_bits, mantissa_bits, exponent_bias)

python/test/unit/language/test_core.py

Lines changed: 3 additions & 3 deletions
@@ -29,6 +29,7 @@
     is_cuda,
     is_interpreter,
     is_hip,
+    is_hip_cdna,
     is_hip_mi200,
     get_arch,
     torch_float8_dtypes,
@@ -3338,13 +3339,12 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
+        if not is_hip_cdna():
+            pytest.skip("scaled_dot only implemented for HIP CDNA")
         if (type_a not in ["e2m1", "e5m2"]) or (type_b not in ["e2m1", "e5m2", "bf16"]):
            pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
         if mma == 16 and K == 64:
             pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
-        arch = triton.runtime.driver.active.get_current_target().arch
-        if "gfx11" in arch or "gfx12" in arch:
-            pytest.skip("scaled_dot not yet implemented for gfx11 and gfx12")
 
     @triton.jit
     def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,

python/test/unit/language/test_pipeliner.py

Lines changed: 3 additions & 18 deletions
@@ -6,22 +6,7 @@
 import triton.language as tl
 import triton.tools.experimental_descriptor
 
-
-def is_cuda():
-    return triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-
-def is_hopper():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
-
-
-def is_hip():
-    return triton.runtime.driver.active.get_current_target().backend == "hip"
-
-
-def is_hip_mi200():
-    target = triton.runtime.driver.active.get_current_target()
-    return target.backend == 'hip' and target.arch == 'gfx90a'
+from triton._internal_testing import is_cuda, is_hopper, is_hip_cdna, is_hip_mi200
 
 
 def check_capabilities():
@@ -229,8 +214,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 @pytest.mark.parametrize("scale", [True, False])
 def test_pipeline_matmul(scale, device):
     check_capabilities()
-    if scale and not is_cuda():
-        pytest.skip("NYI: scale_dot just implemented in CUDA")
+    if scale and not (is_cuda() or is_hip_cdna()):
+        pytest.skip("NYI: scale_dot just implemented in CUDA/HIP")
     M, N, K = 512, 512, 128
     BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
     NUM_STAGES = 4

python/triton/_internal_testing.py

Lines changed: 13 additions & 0 deletions
@@ -36,6 +36,10 @@ def is_cuda():
     return False if target is None else target.backend == "cuda"
 
 
+def is_hopper():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+
+
 def is_hip():
     target = get_current_target()
     return False if target is None else target.backend == "hip"
@@ -46,6 +50,15 @@ def is_hip_mi200():
     return target.backend == 'hip' and target.arch == 'gfx90a'
 
 
+def is_hip_mi300():
+    target = get_current_target()
+    return target.backend == 'hip' and target.arch in ('gfx940', 'gfx941', 'gfx942')
+
+
+def is_hip_cdna():
+    return is_hip_mi200() or is_hip_mi300()
+
+
 def get_arch():
     target = get_current_target()
     return "" if target is None else str(target.arch)
