Skip to content

Commit e5e0081

Browse files
authored
[FRONTEND] Improve errors for TMA desc failure (#8462)
Catch empty tensors and also print out all the arguments when there is an unexpected failure.
1 parent 5745035 commit e5e0081

File tree

4 files changed

+56
-4
lines changed

4 files changed

+56
-4
lines changed

python/test/unit/cuda/test_tma_descriptor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_1d_tma_descriptor_exception(M, BLOCK_M, expect_error):
2323
_ = TensorDescriptor.from_tensor(x, [BLOCK_M])
2424

2525

26-
@pytest.mark.parametrize("M, BLOCK_M, expect_error_m", [(128, 32, False), (125, 33, True)])
26+
@pytest.mark.parametrize("M, BLOCK_M, expect_error_m", [(128, 32, False), (125, 33, True), (0, 32, False)])
2727
@pytest.mark.parametrize("N, BLOCK_N, expect_error_n", [(128, 32, False), (128, 30, True), (127, 32, False)])
2828
def test_2d_tma_descriptor_exception(M, N, BLOCK_M, BLOCK_N, expect_error_n, expect_error_m):
2929
if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 9:
@@ -39,10 +39,14 @@ def test_2d_tma_descriptor_exception(M, N, BLOCK_M, BLOCK_N, expect_error_n, exp
3939

4040
shape_error = expect_error_n or expect_error_m
4141
error_alignment = (N % 16) != 0
42-
expect_error = shape_error or error_alignment
42+
zero_shape_error = M <= 0 or N <= 0
43+
expect_error = shape_error or error_alignment or zero_shape_error
4344

4445
exc_type = ValueError if shape_error else AssertionError
4546
match = "Shape element . must be a power of 2" if shape_error else "strides must be 16-byte aligned"
47+
if zero_shape_error and not shape_error and not error_alignment:
48+
match = "shape must be positive"
49+
exc_type = AssertionError
4650
ctx = pytest.raises(exc_type, match=match) if expect_error else nullcontext()
4751
with ctx:
4852
_ = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_N])

python/triton/experimental/gluon/nvidia/hopper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def __post_init__(self):
2727
elem_bytes = get_primitive_bitwidth(dtype_str) // 8
2828
for stride in self.strides[:-1]:
2929
assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
30+
for shape_dim in self.shape:
31+
assert shape_dim > 0, "shape must be positive"
3032
assert self.strides[-1] == 1, "Last dimension must be contiguous"
3133
assert isinstance(self.layout, NVMMASharedLayout), "Layout must be NVMMASharedLayout"
3234
assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"

python/triton/tools/tensor_descriptor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __post_init__(self):
2424
elem_bytes = self.base.dtype.itemsize
2525
for stride in self.strides[:-1]:
2626
assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
27+
for shape_dim in self.shape:
28+
assert shape_dim > 0, "shape must be positive"
2729
assert self.strides[-1] == 1, "Last dimension must be contiguous"
2830
assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
2931
if self.padding == "nan":

third_party/nvidia/backend/driver.c

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "cuda.h"
22
#include <dlfcn.h>
33
#include <stdbool.h>
4+
#include <stdio.h>
45
#include <stdlib.h>
56
#define PY_SSIZE_T_CLEAN
67
#include <Python.h>
@@ -420,10 +421,53 @@ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
420421
static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
421422
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
422423
getCuTensorMapEncodeTiledHandle);
423-
CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
424+
CUresult res = cuTensorMapEncodeTiled(
424425
&desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
425426
stridesLL, blockSizeInt, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
426-
swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill));
427+
swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill);
428+
if (res != CUDA_SUCCESS) {
429+
const char *str;
430+
cuGetErrorString(res, &str);
431+
char err[4096] = {0};
432+
size_t off = 0;
433+
off += snprintf(
434+
err + off, sizeof(err) - off,
435+
"Triton Error [CUDA]: Failed to create tensor map descriptor: %s\n",
436+
str ? str : "Unknown error");
437+
off += snprintf(err + off, sizeof(err) - off,
438+
"elemType=%d rank=%d global_address=0x%llx elemSize=%d "
439+
"swizzle=%d padding=%d\n",
440+
elemType, rank, (unsigned long long)global_address,
441+
elemSize, swizzle, padding);
442+
off += snprintf(err + off, sizeof(err) - off, "shape=[");
443+
for (int i = 0; i < rank; ++i) {
444+
off +=
445+
snprintf(err + off, sizeof(err) - off, "%llu%s",
446+
(unsigned long long)shapeInt[i], (i + 1 < rank) ? ", " : "");
447+
}
448+
off += snprintf(err + off, sizeof(err) - off, "]\n");
449+
off += snprintf(err + off, sizeof(err) - off, "strides=[");
450+
for (int i = 0; i < rank; ++i) {
451+
off += snprintf(err + off, sizeof(err) - off, "%llu%s",
452+
(unsigned long long)stridesLL[i],
453+
(i + 1 < rank) ? ", " : "");
454+
}
455+
off += snprintf(err + off, sizeof(err) - off, "]\n");
456+
off += snprintf(err + off, sizeof(err) - off, "blockSize=[");
457+
for (int i = 0; i < rank; ++i) {
458+
off += snprintf(err + off, sizeof(err) - off, "%u%s",
459+
(unsigned)blockSizeInt[i], (i + 1 < rank) ? ", " : "");
460+
}
461+
off += snprintf(err + off, sizeof(err) - off, "] elementStrides=[");
462+
for (int i = 0; i < rank; ++i) {
463+
off += snprintf(err + off, sizeof(err) - off, "%u%s",
464+
(unsigned)elementStrides[i], (i + 1 < rank) ? ", " : "");
465+
}
466+
off += snprintf(err + off, sizeof(err) - off, "]\n");
467+
PyErr_SetString(PyExc_RuntimeError, err);
468+
469+
goto cleanup;
470+
}
427471

428472
return (PyObject *)desc;
429473

0 commit comments

Comments
 (0)