
Commit 5700c14

[FRONTEND] Fix and improve minimum dot size checks (#5383)

1. Fix the problem that [m, k, n] rather than [m, n, k] was returned on the NVIDIA backend.
2. Check both int8 and float8.
3. Add a new compiler error test.
4. Fix the dtype check in the AMD backend.

1 parent e3d3851 · commit 5700c14
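For context on item 1: the tuple a backend's `min_dot_size` hook returns is consumed as minimum sizes in [m, n, k] order. The old NVIDIA lambda returned (16, 32, 16) for int8, i.e. [m, k, n], putting the 8-bit K minimum of 32 in the N slot. A minimal sketch of the before/after behavior, with a hypothetical `Int8Operand` stand-in for Triton's real operand types:

```python
# Hypothetical stand-in for a Triton operand type: the real code passes block
# types whose element type is reachable via `.scalar`.
class Int8Operand:
    class scalar:
        primitive_bitwidth = 8

def old_nvidia_check(lhs, rhs):
    # Pre-fix behavior: (16, 32, 16) reads as [m, k, n] -- the K minimum of 32
    # landed in the N slot.
    return (16, 32, 16)

def new_nvidia_check(lhs, rhs):
    # Post-fix behavior: 8-bit operands need K >= 32, in the correct [m, n, k] slot.
    return (16, 16, 32) if lhs.scalar.primitive_bitwidth == 8 else (16, 16, 16)

m_min, n_min, k_min = new_nvidia_check(Int8Operand(), Int8Operand())
assert (m_min, n_min, k_min) == (16, 16, 32)
```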

File tree

4 files changed (+51 -3 lines)

python/test/unit/language/test_compile_errors.py
Lines changed: 37 additions & 1 deletion

@@ -7,7 +7,7 @@
 import triton.language as tl
 from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 import traceback
-from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300
+from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300, is_hip_mi200


 def test_err_undefined_variable():

@@ -379,6 +379,42 @@ def dtype_kernel(dtype: tl.constexpr):
         raise assertion_err from e.value


+@pytest.mark.parametrize("dtype", [tl.float8e5, tl.int8, tl.float16])
+def test_min_dot_size(dtype):
+    error_msg = "Input shapes should have "
+    if is_cuda():
+        if dtype.primitive_bitwidth == 8:
+            error_msg += "M >= 16, N >= 16 and K >= 32"
+        else:
+            error_msg += "M >= 16, N >= 16 and K >= 16"
+    elif is_hip_mi300():
+        if dtype.is_int8():
+            error_msg += "M >= 16, N >= 16 and K >= 16"
+        else:
+            error_msg += "M >= 16, N >= 16 and K >= 8"
+    elif is_hip_mi200():
+        error_msg += "M >= 16, N >= 16 and K >= 8"
+    elif is_hip():
+        error_msg += "M >= 16, N >= 16 and K >= 16"
+    else:
+        pytest.skip("Test only supported on CUDA and HIP")
+
+    @triton.jit
+    def dot_kernel(dtype: tl.constexpr):
+        SIZE: tl.constexpr = 8
+        a = tl.full((SIZE, SIZE), 0.0, dtype)
+        b = tl.full((SIZE, SIZE), 0.0, dtype)
+        tl.dot(a, b)
+
+    with pytest.raises(CompilationError) as e:
+        triton.compile(
+            triton.compiler.ASTSource(fn=dot_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype}))
+    try:
+        assert (error_msg in str(e.value.__cause__))
+    except AssertionError as assertion_err:
+        raise assertion_err from e.value
+
+
 def test_max_num_imprecise_acc_limit():

     @triton.jit
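To run just the new test from a repo checkout (a CUDA or HIP device is needed; on other targets it self-skips), something like the following should work:

```python
import pytest

# Select only the new minimum-dot-size test by name; path is relative to the repo root.
pytest.main(["python/test/unit/language/test_compile_errors.py", "-k", "test_min_dot_size", "-v"])
```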

python/triton/language/semantic.py
Lines changed: 1 addition & 0 deletions

@@ -1473,6 +1473,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona
     assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"

     if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
+        # We upcast because there's no fp8e4b15 type in MLIR
         lhs = cast(lhs, tl.float16, builder)
         rhs = cast(rhs, tl.float16, builder)
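The added comment documents behavior that was already in place: fp8e4b15 has no MLIR counterpart, so both `tl.dot` operands are upcast to float16 before lowering. A rough standalone sketch of that rule (the function and string dtype names are illustrative, not Triton's API):

```python
def upcast_for_dot(lhs_dtype: str, rhs_dtype: str):
    # fp8e4b15 cannot be represented in MLIR, so the frontend casts both
    # dot operands to float16 before emitting IR.
    if "fp8e4b15" in (lhs_dtype, rhs_dtype):
        return ("float16", "float16")
    return (lhs_dtype, rhs_dtype)

assert upcast_for_dot("fp8e4b15", "fp8e4b15") == ("float16", "float16")
assert upcast_for_dot("float32", "float32") == ("float32", "float32")
```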

third_party/amd/backend/compiler.py
Lines changed: 2 additions & 1 deletion

@@ -18,7 +18,8 @@ def min_dot_size(target: GPUTarget):
     # CDNA 3.0 supports k==8 in all mfma variants except for int8
     # (where the smallest `k` supported is 16)
     if "gfx94" in arch_str:
-        return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.is_int8() or rhsType.is_int8()) else (16, 16, 8)
+        return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.scalar.is_int8() or rhsType.scalar.is_int8()) else (
+            16, 16, 8)
     # CDNA 2.0 always supports `k==8`
     if "gfx9" in arch_str:
         return lambda lhsType, rhsType: (16, 16, 8)
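The `.scalar` hop is the dtype fix from item 4: at this call site the operand types are block (tensor) types, so `is_int8()` has to be asked of the element type. A small sketch with hypothetical stand-in classes:

```python
class ScalarType:
    def __init__(self, name: str):
        self.name = name

    def is_int8(self) -> bool:
        return self.name == "int8"

class BlockType:
    # Tensor-of-scalars type; exposes its element type as `.scalar`.
    def __init__(self, scalar: ScalarType):
        self.scalar = scalar

lhs = BlockType(ScalarType("int8"))
rhs = BlockType(ScalarType("float16"))
# BlockType itself has no is_int8(); the element type carries that predicate.
min_mnk = (16, 16, 16) if (lhs.scalar.is_int8() or rhs.scalar.is_int8()) else (16, 16, 8)
assert min_mnk == (16, 16, 16)  # int8 mfma on CDNA 3.0 (gfx94*) needs K >= 16
```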

third_party/nvidia/backend/compiler.py
Lines changed: 11 additions & 1 deletion

@@ -17,7 +17,17 @@


 def min_dot_size(target: GPUTarget):
-    return lambda lhsType, rhsType: (16, 32, 16) if lhsType.is_int8() else (16, 16, 16)
+
+    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:  # [m, n, k]
+        lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
+        rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
+        assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
+        if lhs_bitwidth == 8:
+            return (16, 16, 32)
+        else:
+            return (16, 16, 16)
+
+    return check_dot_compatibility


 @functools.lru_cache()
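Note that the new checker keys on `primitive_bitwidth` rather than `is_int8()`, which is what makes item 2 work: float8 inputs now get the same K >= 32 minimum as int8. A quick illustration with a hypothetical float8 stand-in type:

```python
# Hypothetical stand-in for a float8 operand (e.g. tl.float8e5 also has bitwidth 8).
class Float8Operand:
    class scalar:
        primitive_bitwidth = 8

def check_dot_compatibility(lhs_type, rhs_type):  # mirrors the fixed hook; [m, n, k]
    assert lhs_type.scalar.primitive_bitwidth == rhs_type.scalar.primitive_bitwidth
    return (16, 16, 32) if lhs_type.scalar.primitive_bitwidth == 8 else (16, 16, 16)

assert check_dot_compatibility(Float8Operand(), Float8Operand()) == (16, 16, 32)
```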
