Skip to content

Commit f8ac214

Browse files
Merge commit 'e19644610878ab19ec375d3eabfb464398aade32'
2 parents 59e291d + e196446 commit f8ac214

File tree

23 files changed

+193
-484
lines changed

23 files changed

+193
-484
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
245245
- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage.
246246
- `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx/amdgcn files when `TRITON_KERNEL_OVERRIDE` is set to 1.
247247
- `TRITON_F32_DEFAULT` sets the default input precision of `tl.dot` when using 32-bit floats, which can be either `ieee`, `tf32`, or `tf32x3`.
248+
- `TRITON_FRONT_END_DEBUGGING=1` disables exception wrapping when an error occurs in the compiler frontend, allowing the full stack trace to be seen.
248249

249250
**Kernel Override Steps**
250251

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
8585
// TritonAMDGPUToLLVM passes
8686
mlir::triton::registerConvertTritonAMDGPUToLLVM();
8787
mlir::triton::registerConvertBuiltinFuncToLLVM();
88-
mlir::triton::registerDecomposeUnsupportedAMDConversions();
8988
mlir::triton::registerOptimizeAMDLDSUsage();
9089

9190
// TritonAMDGPUTransforms passes

include/triton/Conversion/TritonGPUToLLVM/Patterns.h

Lines changed: 0 additions & 27 deletions
This file was deleted.

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ add_triton_library(TritonGPUToLLVM
66
AssertOpToLLVM.cpp
77
ControlFlowOpToLLVM.cpp
88
ConvertLayoutOpToLLVM.cpp
9-
DecomposeUnsupportedConversions.cpp
109
ElementwiseOpToLLVM.cpp
1110
FuncOpToLLVM.cpp
1211
GatherOpToLLVM.cpp

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 0 additions & 96 deletions
This file was deleted.

python/test/unit/cuda/test_tensor_descriptor.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,41 @@ def alloc_fn(size: int, align: int, stream: Optional[int]) -> torch.Tensor:
421421
torch.testing.assert_close(expect, out)
422422

423423

424+
@triton.jit(noinline=True)
def tensor_descriptor_arg_helper(in_desc, out_desc, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    # Non-inlined helper that receives tensor descriptors as arguments.
    # Each program handles one (M_BLOCK, N_BLOCK) tile: it loads the tile from
    # `in_desc` and stores its element-wise absolute value through `out_desc`.
    #
    # Tile origin, derived from the 2-D program id.
    moffset = tl.program_id(0) * M_BLOCK
    noffset = tl.program_id(1) * N_BLOCK
    value = in_desc.load([moffset, noffset])
    out_desc.store([moffset, noffset], value.abs())
430+
431+
432+
@requires_tma
@pytest.mark.interpreter
def test_tensor_descriptor_argument():
    """Check that tensor descriptors can be passed to a ``noinline`` helper.

    The kernel builds input and output descriptors on the device and hands
    them to ``tensor_descriptor_arg_helper``; the result is compared against
    ``inp.abs()``.
    """

    @triton.jit
    def kernel(out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
        # Row-major (M, N) descriptors: strides are [N, 1].
        out_desc = tl.make_tensor_descriptor(out_ptr, shape=[M, N], strides=[N, 1], block_shape=[M_BLOCK, N_BLOCK])
        in_desc = tl.make_tensor_descriptor(a_ptr, shape=[M, N], strides=[N, 1], block_shape=[M_BLOCK, N_BLOCK])
        tensor_descriptor_arg_helper(in_desc, out_desc, M_BLOCK, N_BLOCK)

    M, N = 32, 128
    inp = torch.randn((M, N), device="cuda")

    M_BLOCK = 8
    N_BLOCK = 32
    out = inp.new_zeros((M, N))

    def alloc_fn(size: int, align: int, stream: Optional[int]) -> torch.Tensor:
        # `align` and `stream` are accepted but not used by this allocator.
        return torch.empty(size, dtype=torch.int8, device="cuda")

    # NOTE(review): presumably required so make_tensor_descriptor can obtain
    # scratch memory; the allocator is not restored afterwards - confirm that
    # this matches the other tests in the file.
    triton.set_allocator(alloc_fn)

    expect = inp.abs()
    # One program per tile: a (4, 4) grid for the sizes above.
    kernel[(M // M_BLOCK, N // N_BLOCK)](out, inp, M, N, M_BLOCK, N_BLOCK)
    torch.testing.assert_close(expect, out)
457+
458+
424459
@triton.jit
425460
def matmul_kernel_make_tensor_desciptor(a_ptr, b_ptr, c_ptr, #
426461
M, N, K, #

python/triton/compiler/code_generator.py

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# ideally we wouldn't need any runtime component
1717
from ..runtime import JITFunction
1818
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
19+
from . import config
1920

2021
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
2122

@@ -27,29 +28,9 @@ def check_identifier_legality(name, type):
2728
return name
2829

2930

30-
def mangle_ty(ty):
31-
if ty.is_tuple():
32-
return 'T' + '_'.join(map(mangle_ty, ty.types)) + 'T'
33-
if ty.is_ptr():
34-
return 'P' + mangle_ty(ty.element_ty)
35-
if ty.is_int():
36-
SIGNED = language.dtype.SIGNEDNESS.SIGNED
37-
prefix = 'i' if ty.int_signedness == SIGNED else 'u'
38-
return prefix + str(ty.int_bitwidth)
39-
if ty.is_floating():
40-
return str(ty)
41-
if ty.is_block():
42-
elt = mangle_ty(ty.scalar)
43-
shape = '_'.join(map(str, ty.shape))
44-
return f'{elt}S{shape}S'
45-
if ty.is_void():
46-
return 'V'
47-
raise TypeError(f'Unsupported type {ty}')
48-
49-
5031
def mangle_fn(name, arg_tys, constants):
5132
# doesn't mangle ret type, which must be a function of arg tys
52-
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
33+
mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys])
5334
mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
5435
mangled_constants = mangled_constants.replace('.', '_d_')
5536
mangled_constants = mangled_constants.replace("'", '_sq_')
@@ -71,8 +52,8 @@ def _is_constexpr(o: Any) -> bool:
7152
return o is None or isinstance(o, (constexpr, language.core.dtype))
7253

7354

74-
def _is_triton_scalar(o: Any) -> bool:
75-
return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1)
55+
def _is_non_scalar_tensor(o: Any) -> bool:
    """Return True when `o` is a Triton tensor of block type whose `numel` is not 1.

    Anything that is not a Triton tensor, a tensor whose type is not a block,
    and a block with exactly one element all yield False.
    """
    return _is_triton_tensor(o) and (o.type.is_block() and o.type.numel != 1)
7657

7758

7859
def _is_list_like(o: Any) -> bool:
@@ -82,7 +63,7 @@ def _is_list_like(o: Any) -> bool:
8263
def _check_fn_args(node, fn, args):
8364
if fn.noinline:
8465
for idx, arg in enumerate(args):
85-
if not _is_constexpr(arg) and not _is_triton_scalar(arg):
66+
if not _is_constexpr(arg) and _is_non_scalar_tensor(arg):
8667
raise UnsupportedLanguageConstruct(
8768
fn.src, node,
8869
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
@@ -241,26 +222,26 @@ def __init__(self, ret_types, arg_types, constants, attrs):
241222
self.constants = constants
242223
self.attrs = attrs
243224

244-
def return_types_ir(self, builder: ir.builder):
245-
ret_types = []
246-
for ret_ty in self.ret_types:
247-
if ret_ty is None:
225+
def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]:
226+
ir_types = []
227+
for ty in types:
228+
if ty is None:
248229
continue
249-
ir_ty = ret_ty.to_ir(builder)
250-
if isinstance(ir_ty, list):
251-
ret_types.extend(ir_ty)
252-
else:
253-
ret_types.append(ir_ty)
254-
return ret_types
230+
ty._flatten_ir_types(builder, ir_types)
231+
return ir_types
232+
233+
def return_types_ir(self, builder: ir.builder) -> List[ir.type]:
234+
return self.flatten_ir_types(builder, self.ret_types)
255235

256236
def serialize(self, builder: ir.builder):
257237
# fill up IR values in template
258238
# > build function
259239
is_val = lambda path, _: path not in self.constants and _ is not None
260240
val_paths = list(find_paths_if(self.arg_types, is_val))
261-
arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths]
262-
ret_types = self.return_types_ir(builder)
263-
return builder.get_function_ty(arg_types, ret_types)
241+
arg_types = [get_iterable_path(self.arg_types, path) for path in val_paths]
242+
arg_types_ir = self.flatten_ir_types(builder, arg_types)
243+
ret_types_ir = self.return_types_ir(builder)
244+
return builder.get_function_ty(arg_types_ir, ret_types_ir)
264245

265246
def deserialize(self, fn):
266247
# create "template"
@@ -282,9 +263,12 @@ def make_template(ty):
282263
if isinstance(ty, nv_tma_desc_type):
283264
fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
284265
# > add IR values to the template
285-
for i, path in enumerate(val_paths):
266+
cursor = 0
267+
handles = [fn.args(i) for i in range(fn.get_num_args())]
268+
for path in val_paths:
286269
ty = get_iterable_path(self.arg_types, path)
287-
set_iterable_path(vals, path, language.tensor(fn.args(i), ty))
270+
val, cursor = ty._unflatten_ir(handles, cursor)
271+
set_iterable_path(vals, path, val)
288272
# > add constexpr values to the template
289273
constants = self.constants
290274
for path, val in constants.items():
@@ -1218,14 +1202,16 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
12181202
generator.visit(fn.parse())
12191203
except Exception as e:
12201204
# Wrap the error in the callee with the location of the call.
1205+
if config.front_end_debugging():
1206+
raise
12211207
raise CompilationError(self.jit_fn.src, self.cur_node, None) from e
12221208

12231209
callee_ret_type = generator.ret_type
12241210
self.function_ret_types[fn_name] = callee_ret_type
12251211
else:
12261212
callee_ret_type = self.function_ret_types[fn_name]
12271213
symbol = self.module.get_function(fn_name)
1228-
args_val = [arg.handle for arg in args_val]
1214+
args_val = flatten_values_to_ir(args_val)
12291215
call_op = self.builder.call(symbol, args_val)
12301216
if callee_ret_type == language.void:
12311217
return None
@@ -1256,6 +1242,8 @@ def visit_Call(self, node):
12561242
ret = language.tuple(ret)
12571243
return ret
12581244
except Exception as e:
1245+
if config.front_end_debugging():
1246+
raise
12591247
# Normally when we raise a CompilationError, we raise it as
12601248
# `from None`, because the original fileline from the exception
12611249
# is not relevant (and often points into code_generator.py
@@ -1335,6 +1323,8 @@ def visit(self, node):
13351323
except CompilationError:
13361324
raise
13371325
except Exception as e:
1326+
if config.front_end_debugging():
1327+
raise
13381328
# Wrap the error in a CompilationError which contains the source
13391329
# of the @jit function.
13401330
raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None

python/triton/compiler/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ..tools.disasm import get_sass, get_spvdis
1212
# TODO: this shouldn't be here
1313
from .code_generator import ast_to_ttir
14+
from . import config
1415
from pathlib import Path
1516
import re
1617
import functools
@@ -181,7 +182,7 @@ def filter_traceback(e: BaseException):
181182
182183
These are uninteresting to the user -- "just show me *my* code!"
183184
"""
184-
if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1":
185+
if config.front_end_debugging():
185186
return
186187

187188
if e.__cause__ is not None:

python/triton/compiler/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import os
2+
3+
4+
def front_end_debugging() -> bool:
    """Return True when the ``TRITON_FRONT_END_DEBUGGING`` environment variable is ``"1"``.

    The variable is read on every call, so changes made to the environment
    while the process is running are observed. An unset variable, or any
    value other than the exact string ``"1"``, yields False.
    """
    return os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1"

0 commit comments

Comments
 (0)