Commit e196446
[FRONTEND] Support tensordesc in function call arguments (#6262)
This also does a few refactorings:

1. Refactor `mangle_ty(type)` -> `type.mangle()`, and implement it for `tensor_descriptor_type`.
2. Refactor `type.to_ir` -> `type._flatten_ir_types`, which matches the `value._flatten_ir` method, but for types.
3. Update function signature serialization and deserialization to use the new interfaces.

Also, as part of debugging, I updated `TRITON_FRONT_END_DEBUGGING=1` to disable wrapping exceptions in `CompilationError`, which makes the stack trace point directly at the guts of the frontend, making it far easier to debug.
1 parent dd6a540 commit e196446
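For orientation, here is a minimal before/after sketch of the call-site pattern these refactorings describe. It is schematic rather than code lifted from the repository: `arg_tys` stands in for a list of frontend types and `builder` for an `ir.builder`.

```python
# Before: a free function dispatched on the type kind, and to_ir() could
# return either a single IR type or a list of IR types.
mangled = '_'.join(mangle_ty(ty) for ty in arg_tys)
ir_types = []
for ty in arg_tys:
    t = ty.to_ir(builder)
    if isinstance(t, list):
        ir_types.extend(t)
    else:
        ir_types.append(t)

# After: each type mangles itself and appends its flattened IR types to an
# output list, mirroring the existing value._flatten_ir protocol.
mangled = '_'.join(ty.mangle() for ty in arg_tys)
ir_types = []
for ty in arg_tys:
    ty._flatten_ir_types(builder, ir_types)
```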

File tree

6 files changed: +117 −54 lines changed


README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -245,6 +245,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
 - `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage.
 - `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx/amdgcn files when `TRITON_KERNEL_OVERRIDE` is set to 1.
 - `TRITON_F32_DEFAULT` sets the default input precision of `tl.dot` when using 32-bit floats, which can be either `ieee`, `tf32`, or `tf32x3`.
+- `TRITON_FRONT_END_DEBUGGING=1` disables exception wrapping when an error occurs in the compiler frontend, allowing the full stack trace to be seen.
 
 **Kernel Override Steps**
```
python/test/unit/cuda/test_tensor_descriptor.py

Lines changed: 35 additions & 0 deletions

```diff
@@ -421,6 +421,41 @@ def alloc_fn(size: int, align: int, stream: Optional[int]) -> torch.Tensor:
     torch.testing.assert_close(expect, out)
 
 
+@triton.jit(noinline=True)
+def tensor_descriptor_arg_helper(in_desc, out_desc, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
+    moffset = tl.program_id(0) * M_BLOCK
+    noffset = tl.program_id(1) * N_BLOCK
+    value = in_desc.load([moffset, noffset])
+    out_desc.store([moffset, noffset], value.abs())
+
+
+@requires_tma
+@pytest.mark.interpreter
+def test_tensor_descriptor_argument():
+
+    @triton.jit
+    def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
+        out_desc = tl.make_tensor_descriptor(out_ptr, shape=[M, N], strides=[N, 1], block_shape=[M_BLOCK, N_BLOCK])
+        in_desc = tl.make_tensor_descriptor(a_ptr, shape=[M, N], strides=[N, 1], block_shape=[M_BLOCK, N_BLOCK])
+        tensor_descriptor_arg_helper(in_desc, out_desc, M_BLOCK, N_BLOCK)
+
+    M, N = 32, 128
+    inp = torch.randn((M, N), device="cuda")
+
+    M_BLOCK = 8
+    N_BLOCK = 32
+    out = inp.new_zeros((M, N))
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]) -> torch.Tensor:
+        return torch.empty(size, dtype=torch.int8, device="cuda")
+
+    triton.set_allocator(alloc_fn)
+
+    expect = inp.abs()
+    kernel[(M // M_BLOCK, N // N_BLOCK)](out, inp, M, N, M_BLOCK, N_BLOCK)
+    torch.testing.assert_close(expect, out)
+
+
 @triton.jit
 def matmul_kernel_make_tensor_desciptor(a_ptr, b_ptr, c_ptr,  #
                                         M, N, K,  #
```
python/triton/compiler/code_generator.py

Lines changed: 30 additions & 40 deletions

```diff
@@ -16,6 +16,7 @@
 # ideally we wouldn't need any runtime component
 from ..runtime import JITFunction
 from .._utils import find_paths_if, get_iterable_path, set_iterable_path
+from . import config
 
 from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
 
@@ -27,29 +28,9 @@ def check_identifier_legality(name, type):
     return name
 
 
-def mangle_ty(ty):
-    if ty.is_tuple():
-        return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T'
-    if ty.is_ptr():
-        return 'P' + mangle_ty(ty.element_ty)
-    if ty.is_int():
-        SIGNED = language.dtype.SIGNEDNESS.SIGNED
-        prefix = 'i' if ty.int_signedness == SIGNED else 'u'
-        return prefix + str(ty.int_bitwidth)
-    if ty.is_floating():
-        return str(ty)
-    if ty.is_block():
-        elt = mangle_ty(ty.scalar)
-        shape = '_'.join(map(str, ty.shape))
-        return f'{elt}S{shape}S'
-    if ty.is_void():
-        return 'V'
-    raise TypeError(f'Unsupported type {ty}')
-
-
 def mangle_fn(name, arg_tys, constants):
     # doesn't mangle ret type, which must be a function of arg tys
-    mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
+    mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys])
     mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
     mangled_constants = mangled_constants.replace('.', '_d_')
     mangled_constants = mangled_constants.replace("'", '_sq_')
@@ -71,8 +52,8 @@ def _is_constexpr(o: Any) -> bool:
     return o is None or isinstance(o, (constexpr, language.core.dtype))
 
 
-def _is_triton_scalar(o: Any) -> bool:
-    return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1)
+def _is_non_scalar_tensor(o: Any) -> bool:
+    return _is_triton_tensor(o) and (o.type.is_block() and o.type.numel != 1)
 
 
 def _is_list_like(o: Any) -> bool:
@@ -82,7 +63,7 @@ def _is_list_like(o: Any) -> bool:
 def _check_fn_args(node, fn, args):
     if fn.noinline:
         for idx, arg in enumerate(args):
-            if not _is_constexpr(arg) and not _is_triton_scalar(arg):
+            if not _is_constexpr(arg) and _is_non_scalar_tensor(arg):
                 raise UnsupportedLanguageConstruct(
                     fn.src, node,
                     f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
@@ -241,26 +222,26 @@ def __init__(self, ret_types, arg_types, constants, attrs):
         self.constants = constants
         self.attrs = attrs
 
-    def return_types_ir(self, builder: ir.builder):
-        ret_types = []
-        for ret_ty in self.ret_types:
-            if ret_ty is None:
+    def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]:
+        ir_types = []
+        for ty in types:
+            if ty is None:
                 continue
-            ir_ty = ret_ty.to_ir(builder)
-            if isinstance(ir_ty, list):
-                ret_types.extend(ir_ty)
-            else:
-                ret_types.append(ir_ty)
-        return ret_types
+            ty._flatten_ir_types(builder, ir_types)
+        return ir_types
+
+    def return_types_ir(self, builder: ir.builder) -> List[ir.type]:
+        return self.flatten_ir_types(builder, self.ret_types)
 
     def serialize(self, builder: ir.builder):
         # fill up IR values in template
         # > build function
         is_val = lambda path, _: path not in self.constants and _ is not None
         val_paths = list(find_paths_if(self.arg_types, is_val))
-        arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths]
-        ret_types = self.return_types_ir(builder)
-        return builder.get_function_ty(arg_types, ret_types)
+        arg_types = [get_iterable_path(self.arg_types, path) for path in val_paths]
+        arg_types_ir = self.flatten_ir_types(builder, arg_types)
+        ret_types_ir = self.return_types_ir(builder)
+        return builder.get_function_ty(arg_types_ir, ret_types_ir)
 
     def deserialize(self, fn):
         # create "template"
@@ -282,9 +263,12 @@ def make_template(ty):
             if isinstance(ty, nv_tma_desc_type):
                 fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
         # > add IR values to the template
-        for i, path in enumerate(val_paths):
+        cursor = 0
+        handles = [fn.args(i) for i in range(fn.get_num_args())]
+        for path in val_paths:
             ty = get_iterable_path(self.arg_types, path)
-            set_iterable_path(vals, path, language.tensor(fn.args(i), ty))
+            val, cursor = ty._unflatten_ir(handles, cursor)
+            set_iterable_path(vals, path, val)
         # > add constexpr values to the template
         constants = self.constants
         for path, val in constants.items():
@@ -1218,14 +1202,16 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
                 generator.visit(fn.parse())
             except Exception as e:
                 # Wrap the error in the callee with the location of the call.
+                if config.front_end_debugging():
+                    raise
                 raise CompilationError(self.jit_fn.src, self.cur_node, None) from e
 
             callee_ret_type = generator.ret_type
             self.function_ret_types[fn_name] = callee_ret_type
         else:
             callee_ret_type = self.function_ret_types[fn_name]
         symbol = self.module.get_function(fn_name)
-        args_val = [arg.handle for arg in args_val]
+        args_val = flatten_values_to_ir(args_val)
         call_op = self.builder.call(symbol, args_val)
         if callee_ret_type == language.void:
             return None
@@ -1256,6 +1242,8 @@ def visit_Call(self, node):
                 ret = language.tuple(ret)
             return ret
         except Exception as e:
+            if config.front_end_debugging():
+                raise
            # Normally when we raise a CompilationError, we raise it as
            # `from None`, because the original fileline from the exception
            # is not relevant (and often points into code_generator.py
@@ -1335,6 +1323,8 @@ def visit(self, node):
        except CompilationError:
            raise
        except Exception as e:
+            if config.front_end_debugging():
+                raise
            # Wrap the error in a CompilationError which contains the source
            # of the @jit function.
            raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None
```
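The `serialize`/`deserialize` changes above pair a flatten-into-list pass over types with a cursor-driven unflatten pass over IR handles. Below is a self-contained sketch of that protocol using stand-in classes and plain integers for handles; `ScalarTy`, `DescriptorTy`, and the underscore-free method names are illustrative, not Triton's.

```python
from typing import List, Tuple

Handle = int  # stand-in for ir.value


class ScalarTy:
    """A type that occupies exactly one IR handle."""

    def flatten_ir_types(self, out: List[str]) -> None:
        out.append("scalar")

    def unflatten_ir(self, handles: List[Handle], cursor: int) -> Tuple[object, int]:
        return handles[cursor], cursor + 1


class DescriptorTy:
    """A type that flattens to a descriptor handle plus 2 shape and 2 stride scalars."""

    def flatten_ir_types(self, out: List[str]) -> None:
        out.extend(["desc", "i32", "i32", "i32", "i32"])

    def unflatten_ir(self, handles: List[Handle], cursor: int) -> Tuple[object, int]:
        desc = handles[cursor]
        shape = handles[cursor + 1:cursor + 3]
        strides = handles[cursor + 3:cursor + 5]
        return (desc, shape, strides), cursor + 5


arg_types = [ScalarTy(), DescriptorTy(), ScalarTy()]

# Serialization: each argument type appends however many IR types it needs.
ir_types: List[str] = []
for ty in arg_types:
    ty.flatten_ir_types(ir_types)
print(ir_types)  # ['scalar', 'desc', 'i32', 'i32', 'i32', 'i32', 'scalar']

# Deserialization: a single cursor walks the flat handle list in the same order.
handles = list(range(len(ir_types)))
cursor = 0
values = []
for ty in arg_types:
    val, cursor = ty.unflatten_ir(handles, cursor)
    values.append(val)
print(values)  # [0, (1, [2, 3], [4, 5]), 6]
```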

python/triton/compiler/compiler.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -11,6 +11,7 @@
 from ..tools.disasm import get_sass
 # TODO: this shouldn't be here
 from .code_generator import ast_to_ttir
+from . import config
 from pathlib import Path
 import re
 import functools
@@ -179,7 +180,7 @@ def filter_traceback(e: BaseException):
 
     These are uninteresting to the user -- "just show me *my* code!"
     """
-    if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1":
+    if config.front_end_debugging():
         return
 
     if e.__cause__ is not None:
```

python/triton/compiler/config.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -0,0 +1,5 @@
+import os
+
+
+def front_end_debugging():
+    return os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1"
```

python/triton/language/core.py

Lines changed: 44 additions & 13 deletions

```diff
@@ -307,6 +307,12 @@ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
         """
         raise NotImplementedError
 
+    def mangle(self) -> str:
+        raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}")
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        raise NotImplementedError
+
 
 # -----------------------
 # dtype
@@ -502,10 +508,6 @@ def is_ptr():
     def is_const():
         return False
 
-    @staticmethod
-    def is_tuple():
-        return False
-
     def __eq__(self, other: dtype):
         if not isinstance(other, dtype):
             return False
@@ -518,6 +520,9 @@ def __hash__(self):
     def scalar(self):
         return self
 
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        out.append(self.to_ir(builder))
+
     def to_ir(self, builder: ir.builder) -> ir.type:
         if self.name.startswith("fp8"):
             if self.name not in builder.options.supported_fp8_dtypes:
@@ -581,6 +586,17 @@ def __repr__(self):
     def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
         return tensor(handles[cursor], self), cursor + 1
 
+    def mangle(self) -> str:
+        if self.is_int():
+            SIGNED = dtype.SIGNEDNESS.SIGNED
+            prefix = 'i' if self.int_signedness == SIGNED else 'u'
+            return prefix + str(self.int_bitwidth)
+        if self.is_floating():
+            return str(self)
+        if self.is_void():
+            return 'V'
+        return super().mangle()
+
 
 # Some functions have a param named `dtype`, which shadows the `dtype` class.
 # We can't change the param name because it is part of function's public API.
@@ -623,6 +639,9 @@ def __eq__(self, other: pointer_type) -> bool:
     def scalar(self):
         return self
 
+    def mangle(self) -> str:
+        return f"P{self.element_ty.mangle()}"
+
 
 class nv_tma_desc_type(pointer_type):
 
@@ -672,6 +691,11 @@ def __eq__(self, other) -> bool:
     def scalar(self):
         return self.element_ty
 
+    def mangle(self) -> str:
+        elt = self.scalar.mangle()
+        shape = '_'.join(map(str, self.shape))
+        return f'{elt}S{shape}S'
+
 
 class tuple_type(base_type):
 
@@ -686,15 +710,14 @@ def __str__(self):
     def __iter__(self):
         return iter(self.types)
 
-    def to_ir(self, builder: ir.builder):
-        return [ty.to_ir(builder) for ty in self.types]
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]):
+        for ty in self.types:
+            if not isinstance(ty, constexpr):
+                ty._flatten_ir_types(builder, out)
 
     def __getitem__(self, index: int) -> dtype:
         return self.types[index]
 
-    def is_tuple(self):
-        return True
-
     def __eq__(self, other):
         return type(self) is type(other) and self.types == other.types and self.fields == other.fields
 
@@ -705,6 +728,9 @@ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]:
         values.append(value)
         return tuple(values, self), cursor
 
+    def mangle(self):
+        return 'T' + '_'.join(ty.mangle() for ty in self.types) + 'T'
+
 
 class slice_type(dtype):
 
@@ -1263,8 +1289,8 @@ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
         value = tensor_descriptor_base(handles[cursor], self.block_type)
         return value, cursor + 1
 
-    def to_ir(self, builder: ir.builder):
-        return builder.create_tensor_descriptor_type(self.block_type.to_ir(builder))
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder)))
 
     def __str__(self) -> str:
         # ex. "tensor_descriptor<float32[16, 32]>"
@@ -1278,6 +1304,9 @@ def __eq__(self, other) -> bool:
     def __neq__(self, other) -> bool:
         return not (self == other)
 
+    def mangle(self) -> str:
+        return f"TD{self.block_type.mangle()}"
+
 
 class tensor_descriptor_base(base_value):
     """"
@@ -1363,8 +1392,10 @@ def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]:
         value = tensor_descriptor(handle, shape, strides, self.block_type)
         return value, cursor
 
-    def to_ir(self, builder: ir.builder):
-        return [super().to_ir(builder), *self.shape_type.to_ir(builder), *self.strides_type.to_ir(builder)]
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        super()._flatten_ir_types(builder, out)
+        self.shape_type._flatten_ir_types(builder, out)
+        self.strides_type._flatten_ir_types(builder, out)
 
     def __eq__(self, other):
         return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type
```
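To make the new per-type `mangle()` rules concrete, here is a small standalone restatement of them; the helper functions are illustrative, and `"fp32"` is assumed to be what `str()` of a 32-bit float dtype yields.

```python
def mangle_int(bitwidth: int, signed: bool = True) -> str:
    # 'i' / 'u' prefix plus bitwidth, as in dtype.mangle()
    return ('i' if signed else 'u') + str(bitwidth)


def mangle_pointer(element: str) -> str:
    return f"P{element}"


def mangle_block(scalar: str, shape) -> str:
    return f"{scalar}S{'_'.join(map(str, shape))}S"


def mangle_tensor_descriptor(block: str) -> str:
    return f"TD{block}"


print(mangle_pointer("fp32"))                                     # Pfp32
print(mangle_block(mangle_int(32), (16, 16)))                     # i32S16_16S
print(mangle_tensor_descriptor(mangle_block("fp32", (16, 32))))   # TDfp32S16_32S
```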
