Skip to content

Commit 5b6bf5d

Browse files
authored
[Frontend] Improve error when descriptor atomic_{min,max} cannot fallback (#6865)
Currently we fail in the middle of the tensor descriptor rewrite pass, whereas this pre-empts it and raises the error from the frontend.
1 parent 5eee385 commit 5b6bf5d

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,8 +1563,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
15631563
fallback_supported = dtype in FALLBACK_SUPPORTED_REDUCE_DTYPES[kind]
15641564
supported = native_supported if is_native else fallback_supported
15651565
if not supported:
1566-
exc_type = CompilationError if not native_supported else RuntimeError
1567-
with pytest.raises(exc_type):
1566+
with pytest.raises(CompilationError):
15681567
kernel[(grid_m, grid_n)](out_desc, out, inp, M, N, M_BLOCK, N_BLOCK, kind, num_ctas=num_ctas)
15691568
return
15701569

python/triton/language/semantic.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from typing import List, Optional, Sequence, Tuple, TypeVar
55
import numbers
66

7+
from triton.runtime import driver
8+
79
from .._C.libtriton import ir
810
from . import core as tl
911

@@ -1180,17 +1182,28 @@ def descriptor_atomic_add(desc: tl.tensor_descriptor_base, value: tl.tensor, off
11801182
return tl.tensor(builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
11811183

11821184

1185+
def _has_native_tma():
    """Return True when the active target natively supports TMA (CUDA sm90+)."""
    target = driver.active.get_current_target()
    if target.backend != "cuda":
        return False
    return target.arch >= 90
1188+
1189+
1190+
def _descriptor_atomic_min_max_supported(dtype):
    """Assert that `dtype` is usable with descriptor atomic min/max.

    Integer dtypes are always accepted; 16-bit float dtypes additionally
    require native TMA support (no fallback path exists for them), so the
    error is raised here in the frontend rather than mid-rewrite.
    """
    allowed = {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}
    assert dtype in allowed, "Unsupported dtype"
    fp16_like = {tl.float16, tl.bfloat16}
    if dtype in fp16_like:
        assert _has_native_tma(), "16-bit float types require native tma support"
1194+
1195+
11831196
def descriptor_atomic_min(desc: tl.tensor_descriptor_base, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
    """Emit a descriptor-based atomic MIN reduce of `value` at `offsets`.

    Validates the store-like shape/dtype contract and that the dtype is
    supported for min/max before lowering, so unsupported cases fail in
    the frontend. Returns a void tensor wrapping the reduce op.
    """
    validate_store_like(desc, value, offsets)
    _descriptor_atomic_min_max_supported(desc.dtype)
    ir_offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
    handle = builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.MIN, desc.handle, value.handle, ir_offsets)
    return tl.tensor(handle, tl.void)
11891202

11901203

11911204
def descriptor_atomic_max(desc: tl.tensor_descriptor_base, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
    """Emit a descriptor-based atomic MAX reduce of `value` at `offsets`.

    Validates the store-like shape/dtype contract and that the dtype is
    supported for min/max before lowering, so unsupported cases fail in
    the frontend. Returns a void tensor wrapping the reduce op.
    """
    validate_store_like(desc, value, offsets)
    _descriptor_atomic_min_max_supported(desc.dtype)
    ir_offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
    handle = builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.MAX, desc.handle, value.handle, ir_offsets)
    return tl.tensor(handle, tl.void)

0 commit comments

Comments
 (0)