|
4 | 4 | from typing import List, Optional, Sequence, Tuple, TypeVar
|
5 | 5 | import numbers
|
6 | 6 |
|
| 7 | +from triton.runtime import driver |
| 8 | + |
7 | 9 | from .._C.libtriton import ir
|
8 | 10 | from . import core as tl
|
9 | 11 |
|
@@ -1180,17 +1182,28 @@ def descriptor_atomic_add(desc: tl.tensor_descriptor_base, value: tl.tensor, off
|
1180 | 1182 | return tl.tensor(builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
|
1181 | 1183 |
|
1182 | 1184 |
|
| 1185 | +def _has_native_tma(): |
| 1186 | + target = driver.active.get_current_target() |
| 1187 | + return (target.backend == "cuda" and target.arch >= 90) |
| 1188 | + |
| 1189 | + |
| 1190 | +def _descriptor_atomic_min_max_supported(dtype): |
| 1191 | + assert dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}, "Unsupported dtype" |
| 1192 | + if dtype in {tl.float16, tl.bfloat16}: |
| 1193 | + assert _has_native_tma(), "16-bit float types require native tma support" |
| 1194 | + |
| 1195 | + |
1183 | 1196 | def descriptor_atomic_min(desc: tl.tensor_descriptor_base, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
|
1184 | 1197 | validate_store_like(desc, value, offsets)
|
1185 |
| - assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}, "Unsupported dtype" |
| 1198 | + _descriptor_atomic_min_max_supported(desc.dtype) |
1186 | 1199 | offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
|
1187 | 1200 | kind = ir.DESCRIPTOR_REDUCE_KIND.MIN
|
1188 | 1201 | return tl.tensor(builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
|
1189 | 1202 |
|
1190 | 1203 |
|
1191 | 1204 | def descriptor_atomic_max(desc: tl.tensor_descriptor_base, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
|
1192 | 1205 | validate_store_like(desc, value, offsets)
|
1193 |
| - assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}, "Unsupported dtype" |
| 1206 | + _descriptor_atomic_min_max_supported(desc.dtype) |
1194 | 1207 | offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
|
1195 | 1208 | kind = ir.DESCRIPTOR_REDUCE_KIND.MAX
|
1196 | 1209 | return tl.tensor(builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)
|
|
0 commit comments