Skip to content

Commit 1f53afc

Browse files
authored
[TensorDescriptor] Improve error from creating invalid descriptor (#7028)
This adds validation when creating a `TensorDescriptor` object in python, before it's been passed to the kernel. This not only improves the error message from the generic "invalid argument", it also means the stack trace will point to the code causing the error rather than the bowels of the kernel launch function.
1 parent 38f8167 commit 1f53afc

File tree

6 files changed

+109
-94
lines changed

6 files changed

+109
-94
lines changed

python/test/unit/cuda/test_tma_descriptor.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,9 @@
11
from contextlib import nullcontext
22
import pytest
33
import torch
4-
import triton
54
from triton.tools.tensor_descriptor import TensorDescriptor
65

76

8-
@triton.jit
9-
def dummy_kernel(desc):
10-
pass
11-
12-
137
@pytest.mark.parametrize("M, BLOCK_M, expect_error", [(128, 32, False), (127, 32, False), (128, 31, True)])
148
def test_1d_tma_descriptor_exception(M, BLOCK_M, expect_error):
159
if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 9:
@@ -22,10 +16,9 @@ def test_1d_tma_descriptor_exception(M, BLOCK_M, expect_error):
2216
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY
2317
assert x.data_ptr() % 16 == 0
2418

25-
desc = TensorDescriptor.from_tensor(x, [BLOCK_M])
2619
ctx = pytest.raises(ValueError, match="Shape element 0 must be a power of 2") if expect_error else nullcontext()
2720
with ctx:
28-
dummy_kernel[(1, )](desc)
21+
_ = TensorDescriptor.from_tensor(x, [BLOCK_M])
2922

3023

3124
@pytest.mark.parametrize("M, BLOCK_M, expect_error_m", [(128, 32, False), (125, 33, True)])
@@ -42,14 +35,12 @@ def test_2d_tma_descriptor_exception(M, N, BLOCK_M, BLOCK_N, expect_error_n, exp
4235
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY
4336
assert A.data_ptr() % 16 == 0
4437

45-
desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_N])
46-
4738
shape_error = expect_error_n or expect_error_m
4839
error_alignment = (N % 16) != 0
4940
expect_error = shape_error or error_alignment
5041

51-
exc_type = ValueError if shape_error else RuntimeError
52-
match = "Shape element . must be a power of 2" if shape_error else "Triton Error \\[CUDA\\]: invalid argument"
42+
exc_type = ValueError if shape_error else AssertionError
43+
match = "Shape element . must be a power of 2" if shape_error else "strides must be 16-byte aligned"
5344
ctx = pytest.raises(exc_type, match=match) if expect_error else nullcontext()
5445
with ctx:
55-
dummy_kernel[(1, )](desc)
46+
_ = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_N])

python/triton/_utils.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

33
from functools import reduce
4-
from typing import Any, Callable, TYPE_CHECKING, Union
4+
from typing import Any, Callable, TYPE_CHECKING, Union, List, Dict
55

66
if TYPE_CHECKING:
77
from .language import core
88
IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type]
99
ObjPath = tuple[int, ...]
1010

11+
TRITON_MAX_TENSOR_NUMEL = 1048576
12+
1113

1214
def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any:
1315
return reduce(lambda a, idx: a[idx], path, iterable) # type: ignore[index]
@@ -35,3 +37,88 @@ def _impl(path: tuple[int, ...], current: Any):
3537
_impl((), iterable)
3638

3739
return list(ret.keys())
40+
41+
42+
def is_power_of_two(x):
43+
return (x & (x - 1)) == 0
44+
45+
46+
def validate_block_shape(shape: List[int]):
47+
numel = 1
48+
for i, d in enumerate(shape):
49+
if not isinstance(d, int):
50+
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
51+
if not is_power_of_two(d):
52+
raise ValueError(f"Shape element {i} must be a power of 2")
53+
numel *= d
54+
55+
if numel > TRITON_MAX_TENSOR_NUMEL:
56+
raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
57+
return numel
58+
59+
60+
type_canonicalisation_dict = {
61+
# we canonicalise all bools to be unsigned:
62+
"bool": "u1",
63+
"int1": "u1",
64+
"uint1": "u1",
65+
"i1": "u1",
66+
# floating-point dtypes:
67+
"float8e4nv": "fp8e4nv",
68+
"float8e5": "fp8e5",
69+
"float8e4b15": "fp8e4b15",
70+
"float8_e4m3fn": "fp8e4nv",
71+
"float8e4b8": "fp8e4b8",
72+
"float8_e4m3fnuz": "fp8e4b8",
73+
"float8_e5m2": "fp8e5",
74+
"float8e5b16": "fp8e5b16",
75+
"float8_e5m2fnuz": "fp8e5b16",
76+
"half": "fp16",
77+
"float16": "fp16",
78+
"bfloat16": "bf16",
79+
"float": "fp32",
80+
"float32": "fp32",
81+
"double": "fp64",
82+
"float64": "fp64",
83+
# signed integers:
84+
"int8": "i8",
85+
"int16": "i16",
86+
"int": "i32",
87+
"int32": "i32",
88+
"int64": "i64",
89+
# unsigned integers:
90+
"uint8": "u8",
91+
"uint16": "u16",
92+
"uint32": "u32",
93+
"uint64": "u64",
94+
"void": "void",
95+
}
96+
97+
for v in list(type_canonicalisation_dict.values()):
98+
type_canonicalisation_dict[v] = v
99+
100+
101+
def canonicalize_dtype(dtype):
102+
dtype_str = str(dtype).split(".")[-1]
103+
return type_canonicalisation_dict[dtype_str]
104+
105+
106+
BITWIDTH_DICT: Dict[str, int] = {
107+
**{f"u{n}": n
108+
for n in (1, 8, 16, 32, 64)},
109+
**{f"i{n}": n
110+
for n in (1, 8, 16, 32, 64)},
111+
**{f"fp{n}": n
112+
for n in (16, 32, 64)},
113+
**{f"fp8{suffix}": 8
114+
for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")},
115+
"bf16": 16,
116+
"void": 0,
117+
}
118+
119+
for k, v in type_canonicalisation_dict.items():
120+
BITWIDTH_DICT[k] = BITWIDTH_DICT[v]
121+
122+
123+
def get_primitive_bitwidth(dtype: str) -> int:
124+
return BITWIDTH_DICT[dtype]

python/triton/language/_utils.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

python/triton/language/core.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from .._C.libtriton import ir
1616
from . import semantic
17-
from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape
17+
from .._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape, get_primitive_bitwidth
1818

1919
T = TypeVar('T')
2020

@@ -402,55 +402,43 @@ def __init__(self, name):
402402
name = _unwrap_if_constexpr(name)
403403
self.name = name
404404
assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name
405+
self.primitive_bitwidth = get_primitive_bitwidth(name)
405406
if name in dtype.SINT_TYPES:
406407
self.int_signedness = dtype.SIGNEDNESS.SIGNED
407-
self.int_bitwidth = int(name.split('int')[-1])
408-
self.primitive_bitwidth = self.int_bitwidth
408+
self.int_bitwidth = self.primitive_bitwidth
409409
elif name in dtype.UINT_TYPES:
410410
self.int_signedness = dtype.SIGNEDNESS.UNSIGNED
411-
self.int_bitwidth = int(name.split('int')[-1])
412-
self.primitive_bitwidth = self.int_bitwidth
411+
self.int_bitwidth = self.primitive_bitwidth
413412
elif name in dtype.FP_TYPES:
414413
if name == 'fp8e4b15':
415414
self.fp_mantissa_width = 3
416-
self.primitive_bitwidth = 8
417415
self.exponent_bias = 15
418416
elif name == 'fp8e4nv':
419417
self.fp_mantissa_width = 3
420-
self.primitive_bitwidth = 8
421418
self.exponent_bias = 7
422419
elif name == 'fp8e4b8':
423420
self.fp_mantissa_width = 3
424-
self.primitive_bitwidth = 8
425421
self.exponent_bias = 8
426422
elif name == 'fp8e5':
427423
self.fp_mantissa_width = 2
428-
self.primitive_bitwidth = 8
429424
self.exponent_bias = 15
430425
elif name == 'fp8e5b16':
431426
self.fp_mantissa_width = 2
432-
self.primitive_bitwidth = 8
433427
self.exponent_bias = 16
434428
elif name == 'fp16':
435429
self.fp_mantissa_width = 10
436-
self.primitive_bitwidth = 16
437430
self.exponent_bias = 15
438431
elif name == 'bf16':
439432
self.fp_mantissa_width = 7
440-
self.primitive_bitwidth = 16
441433
self.exponent_bias = 127
442434
elif name == 'fp32':
443435
self.fp_mantissa_width = 23
444-
self.primitive_bitwidth = 32
445436
self.exponent_bias = 127
446437
elif name == 'fp64':
447438
self.fp_mantissa_width = 52
448-
self.primitive_bitwidth = 64
449439
self.exponent_bias = 1023
450440
else:
451441
raise RuntimeError(f'Unsupported floating-point type {name}')
452-
elif name == 'void':
453-
self.primitive_bitwidth = 0
454442

455443
def is_fp8(self):
456444
return 'fp8' in self.name

python/triton/runtime/jit.py

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from types import ModuleType
1515
from .. import knobs
1616
from ..runtime.driver import driver
17-
from .._utils import find_paths_if, get_iterable_path
17+
from .._utils import find_paths_if, get_iterable_path, type_canonicalisation_dict, canonicalize_dtype
1818

1919
TRITON_MODULE = __name__[:-len(".runtime.jit")]
2020

@@ -329,7 +329,7 @@ def specialize_impl(arg, is_const=False, specialize_value=True, align=True):
329329
dsk = (arg.dtype, is_const)
330330
res = dtype2str.get(dsk, None)
331331
if res is None:
332-
res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]]
332+
res = ("*k" if dsk[1] else "*") + canonicalize_dtype(dsk[0])
333333
dtype2str[dsk] = res
334334
key = specialize_extra(arg, "tensor", align=align) if specialize_value else None
335335
return (res, key)
@@ -347,7 +347,7 @@ def specialize_impl(arg, is_const=False, specialize_value=True, align=True):
347347
return (tys, keys)
348348
elif isinstance(arg, TensorDescriptor):
349349
assert hasattr(arg.base, "data_ptr")
350-
inner = type_canonicalisation_dict[str(arg.base.dtype).split('.')[-1]]
350+
inner = canonicalize_dtype(arg.base.dtype)
351351
return (f"tensordesc<{inner}{list(arg.block_shape)}>", None)
352352
else:
353353
raise TypeError("Unsupported type: %s" % type(arg))
@@ -445,46 +445,6 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
445445
return func_namespace['dynamic_func']
446446

447447

448-
type_canonicalisation_dict = {
449-
# we canonicalise all bools to be unsigned:
450-
"bool": "u1",
451-
"int1": "u1",
452-
"uint1": "u1",
453-
"i1": "u1",
454-
# floating-point dtypes:
455-
"float8e4nv": "fp8e4nv",
456-
"float8e5": "fp8e5",
457-
"float8e4b15": "fp8e4b15",
458-
"float8_e4m3fn": "fp8e4nv",
459-
"float8e4b8": "fp8e4b8",
460-
"float8_e4m3fnuz": "fp8e4b8",
461-
"float8_e5m2": "fp8e5",
462-
"float8e5b16": "fp8e5b16",
463-
"float8_e5m2fnuz": "fp8e5b16",
464-
"half": "fp16",
465-
"float16": "fp16",
466-
"bfloat16": "bf16",
467-
"float": "fp32",
468-
"float32": "fp32",
469-
"double": "fp64",
470-
"float64": "fp64",
471-
# signed integers:
472-
"int8": "i8",
473-
"int16": "i16",
474-
"int": "i32",
475-
"int32": "i32",
476-
"int64": "i64",
477-
# unsigned integers:
478-
"uint8": "u8",
479-
"uint16": "u16",
480-
"uint32": "u32",
481-
"uint64": "u64",
482-
}
483-
484-
for v in list(type_canonicalisation_dict.values()):
485-
type_canonicalisation_dict[v] = v
486-
487-
488448
def get_full_name(fn):
489449
return f"{fn.__module__}.{fn.__qualname__}"
490450

python/triton/tools/tensor_descriptor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dataclasses import dataclass
22
from typing import List, Any
3+
from triton._utils import validate_block_shape, canonicalize_dtype, get_primitive_bitwidth
34

45

56
@dataclass
@@ -13,6 +14,15 @@ def __post_init__(self):
1314
rank = len(self.shape)
1415
assert len(self.strides) == rank, f"rank mismatch: {self}"
1516
assert len(self.block_shape) == rank, f"rank mismatch: {self}"
17+
assert rank > 0, "rank must not be zero"
18+
assert rank <= 5, "rank cannot be more than 5"
19+
assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
20+
validate_block_shape(self.block_shape)
21+
dtype_str = canonicalize_dtype(self.base.dtype)
22+
elem_bytes = get_primitive_bitwidth(dtype_str) // 8
23+
for stride in self.strides[:-1]:
24+
assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
25+
assert self.strides[-1] == 1, "Last dimension must be contiguous"
1626

1727
@staticmethod
1828
def from_tensor(tensor: Any, block_shape: List[int]):

0 commit comments

Comments
 (0)