Skip to content

Commit 5c30929

Browse files
ThomasRaoux authored and liuyunqi20 committed
Fix assert loc for cases where assert is in an inlined func (#4840)
1 parent e9a226c commit 5c30929

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,9 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
5252
int line = 0;
5353
int col = 0;
5454

55+
while (auto callLoc = dyn_cast<CallSiteLoc>(loc))
56+
loc = callLoc.getCallee();
57+
5558
if (auto fileLineColLoc = dyn_cast<FileLineColLoc>(loc)) {
5659
file = fileLineColLoc.getFilename();
5760
line = fileLineColLoc.getLine();

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1693,3 +1693,21 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
16931693
tt.return
16941694
}
16951695
}
1696+
1697+
// -----
1698+
1699+
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
1700+
// CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32}
1701+
// CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32}
1702+
// CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32}
1703+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
1704+
tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) {
1705+
tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5)
1706+
tt.return
1707+
}
1708+
}
1709+
#loc1 = loc("outer_call":33:8)
1710+
#loc2 = loc("top_func":47:8)
1711+
#loc3 = loc("inner_call":29:28)
1712+
#loc4 = loc(callsite(#loc3 at #loc1))
1713+
#loc5 = loc(callsite(#loc4 at #loc2))

0 commit comments

Comments
 (0)