Skip to content

Commit 185ad59

Browse files
Get file, line and col from MLIR instead of inspecting python frames. Function name will be always set to "unknown" for now. (#4797)
Currently, `device_assert` takes in separate `filename`, `funcname`, `lineno` and it gets them by inspecting python frames in core.py. It's clunky and in practice it is often unknown. Since the loc information is already in the builder, it could be used directly, which is what this PR does.
1 parent 615bae8 commit 185ad59

File tree

7 files changed

+25
-39
lines changed

7 files changed

+25
-39
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,11 +844,11 @@ def TT_PrintOp : TT_Op<"print", [SameVariadicOperandSize, MemoryEffects<[MemWrit
844844
def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
845845
let summary = "Device-side assert, as in CUDA for correctness checking";
846846
let description = [{
847-
`tt.assert` takes a condition tensor, a message string, a file string, a function string, and a line number.
847+
`tt.assert` takes a condition tensor and a message string.
848848
If the condition is false, the message is printed, and the program is aborted.
849849
}];
850-
let arguments = (ins TT_Tensor:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line);
851-
let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)";
850+
let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
851+
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
852852
}
853853

854854
//

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,30 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
3434
return failure();
3535
}
3636
}
37-
llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(),
38-
adaptor.getFunc(), adaptor.getLine(), rewriter);
37+
llAssert(op, condition, adaptor.getMessage(), rewriter);
3938
rewriter.eraseOp(op);
4039
return success();
4140
}
4241
// op: the op at which the assert is inserted. Unlike printf, we need to
4342
// know about the op to split the block.
4443
void llAssert(Operation *op, Value condition, StringRef message,
45-
StringRef file, StringRef func, int line,
4644
ConversionPatternRewriter &rewriter) const {
4745
ConversionPatternRewriter::InsertionGuard guard(rewriter);
46+
4847
auto ctx = rewriter.getContext();
4948
auto loc = op->getLoc();
49+
50+
StringRef file = "unknown";
51+
StringRef func = "unknown";
52+
int line = 0;
53+
int col = 0;
54+
55+
if (auto fileLineColLoc = dyn_cast<FileLineColLoc>(loc)) {
56+
file = fileLineColLoc.getFilename();
57+
line = fileLineColLoc.getLine();
58+
col = fileLineColLoc.getColumn();
59+
}
60+
5061
// #block1
5162
// if (condition) {
5263
// #block2

python/src/ir.cc

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,17 +1546,10 @@ void init_triton_ir(py::module &&m) {
15461546
})
15471547
.def("create_assert",
15481548
[](TritonOpBuilder &self, Value &condition,
1549-
const std::string &message, const std::string &fileName,
1550-
const std::string &funcName, unsigned lineNo) -> void {
1549+
const std::string &message) -> void {
15511550
auto messageAttr = StringAttr::get(self.getBuilder().getContext(),
15521551
llvm::StringRef(message));
1553-
auto fileNameAttr = StringAttr::get(self.getBuilder().getContext(),
1554-
llvm::StringRef(fileName));
1555-
auto funcNameAttr = StringAttr::get(self.getBuilder().getContext(),
1556-
llvm::StringRef(funcName));
1557-
auto lineNoAttr = self.getBuilder().getI32IntegerAttr(lineNo);
1558-
self.create<AssertOp>(condition, messageAttr, fileNameAttr,
1559-
funcNameAttr, lineNoAttr);
1552+
self.create<AssertOp>(condition, messageAttr);
15601553
})
15611554
.def("create_assume",
15621555
[](TritonOpBuilder &self, Value &condition) {

python/triton/language/core.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2308,25 +2308,7 @@ def device_assert(cond, msg="", _builder=None):
23082308
:param msg: the message to print if the assertion fails. This is required to be a string literal.
23092309
'''
23102310
msg = _constexpr_to_value(msg)
2311-
import inspect
2312-
frame = inspect.currentframe()
2313-
module = inspect.getmodule(frame)
2314-
# The triton function module doesn't have the name attribute.
2315-
# We use this trick to find the caller.
2316-
while hasattr(module, "__name__"):
2317-
frame = frame.f_back
2318-
module = inspect.getmodule(frame)
2319-
lineno = 0
2320-
func_name = 'unknown'
2321-
file_name = 'unknown'
2322-
if frame is not None and frame.f_back is not None:
2323-
func_name = frame.f_code.co_name
2324-
file_name = frame.f_back.f_code.co_filename
2325-
# TODO: The line number currently indicates the line
2326-
# where the triton function is called but not where the
2327-
# device_assert is called. Need to enhance this.
2328-
lineno = frame.f_back.f_lineno
2329-
return semantic.device_assert(semantic.to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
2311+
return semantic.device_assert(semantic.to_tensor(cond, _builder), msg, _builder)
23302312

23312313

23322314
@builtin

python/triton/language/semantic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,12 +1631,12 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil
16311631
return tl.tensor(builder.create_print(prefix, hex, new_args, is_signed), tl.void)
16321632

16331633

1634-
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
1634+
def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
16351635
cond_ty = cond.type
16361636
if not cond_ty.is_block():
16371637
cond_ty = tl.block_type(cond_ty.scalar, (1, ))
16381638
cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
1639-
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)
1639+
return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)
16401640

16411641

16421642
def assume(cond, builder: ir.builder) -> tl.tensor:

python/triton/runtime/interpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,9 +643,9 @@ def create_print(self, prefix, hex, values, isSigned):
643643
if hex:
644644
np.set_printoptions(formatter=None)
645645

646-
def create_assert(self, condition, message, fileName, funcName, lineNo):
646+
def create_assert(self, condition, message):
647647
# Interpreter's device_assert function has a different format than Triton's device_assert
648-
assert condition, f"{message} in {fileName}:{funcName}:{lineNo}"
648+
assert condition, f"{message}"
649649

650650
def create_assume(self, condition):
651651
assert condition, "Assume failed"

test/TritonGPU/combine.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2307,7 +2307,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 :
23072307
tt.func @assertop(%ptr: tensor<1024x!tt.ptr<i1>, #blocked>) {
23082308
%0 = tt.load %ptr : tensor<1024x!tt.ptr<i1>, #blocked>
23092309
%1 = triton_gpu.convert_layout %0 : tensor<1024xi1, #blocked> -> tensor<1024xi1, #blocked1>
2310-
tt.assert %1, "cond must be true ", "unknown", "unknown", 0 : tensor<1024xi1, #blocked1>
2310+
tt.assert %1, "cond must be true " : tensor<1024xi1, #blocked1>
23112311
tt.return
23122312
}
23132313
}

0 commit comments

Comments
 (0)