Skip to content

Commit 6958807

Browse files
authored
[AMD] Enable mixed precision matmul test (#5177)
This commit enables mixed precision matmul test for AMD backend. For FP8 E4M3, we test `fp8e4m3fnuz` given that's natively supported on MI300 series.
1 parent 9aa114a commit 6958807

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

python/test/regression/test_cast_matmul.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -11,13 +11,22 @@
 import triton
 import triton.language as tl
+from triton._internal_testing import is_hip_mi300, is_cuda
 
 input_dtypes = ["float16", "float32", "float64"]
-if triton.runtime.driver.active.get_current_target().backend == "cuda":
+if is_cuda():
     input_dtypes += ["int8", "float8_e5m2"]
     cc = torch.cuda.get_device_capability(0)
     if cc >= (8, 9):
         input_dtypes += ["float8_e4m3fn"]
+elif is_hip_mi300():
+    input_dtypes += [
+        "int8",
+        "float8_e5m2",
+        # natively supported on mi300 (see CDNA3 ISA, section 7.2)
+        "float8_e4m3fnuz",
+    ]
 out_dtypes = ["float16", "float32"]
```

2332

```diff
@@ -85,7 +94,7 @@ def test_cast_matmul(M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype):
     def init_tensor(dtype, shape):
         if dtype == torch.int8:
             return torch.randint(0, 2, shape, device=device, dtype=dtype)
-        elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
+        elif dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2):
            return torch.randn(shape, device=device, dtype=torch.float16).to(dtype)
         else:
             return torch.randn(shape, device=device, dtype=dtype)
```

0 commit comments

Comments (0)