Commit 9a01fe9
[TMA] Fix lowering TMA load when 2 users of differing encodings (#7398)
Recently a user reported a crash during TMA lowering in a kernel that roughly looks like:

```python
y = Y_desc.load([offset, 0])
for d_offset in tl.range(0, D, BLOCK_D):
    x = X_desc.load([offset, d_offset])
    xt = tl.trans(x)
    acc = tl.dot(xt, y)
    out += tl.dot(x, tl.dot(xt, y).to(dtype))
```

The error shows up as:

```
error: operand #0 does not dominate this use
    xt = tl.trans(x)
```

Here's a minimized version of the faulty ttgir; note that `%39` uses `%48` before it is defined, which is the dominance violation the verifier reports:

```
%36 = "ttg.local_alloc"() ...
%39 = "ttg.local_alloc"(%48)
%40 = "ttg.memdesc_trans"(%39) <{order = array<i32: 1, 0>}> ...
%48 = "ttg.local_load"(%36)
"scf.yield"(%47#0) : (tensor<64x64xf32, #mma1>) -> ()
```

---

In `replaceUsesAndPropagateType` there are two places where the insertion point is changed, but only one of them is scoped by an insertion guard. The fix in this PR is to scope the other one, too. A lit test validating the fixed behavior is included as well.
1 parent 6a38bee commit 9a01fe9
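
For context on the fix: `OpBuilder::InsertionGuard` is MLIR's RAII helper that records the builder's insertion point when constructed and restores it when destroyed. The sketch below is illustrative only, not the actual Triton code (the function name `replaceUsesSketch` is made up); it shows the pattern the patch applies, with one guard at function scope so the `setInsertionPoint(user)` calls made per user cannot leak out to the caller.

```cpp
#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Illustrative sketch of the guarded pattern; not the real
// replaceUsesAndPropagateType implementation.
static void replaceUsesSketch(OpBuilder &builder, Operation *oldUse, Value val) {
  // RAII guard: remembers the caller's insertion point now and restores it
  // when this function returns, regardless of where the builder moves below.
  OpBuilder::InsertionGuard guard(builder);

  // Collect users first so creating replacements does not disturb the use
  // iteration (the real utility likewise gathers its work up front).
  llvm::SmallVector<Operation *> users;
  for (Operation *user : oldUse->getResult(0).getUsers())
    users.push_back(user);

  for (Operation *user : users) {
    // Each replacement is created right before its user so it dominates it.
    builder.setInsertionPoint(user);
    // ... build the replacement op for `user` from `val` here ...
  }
  // ~InsertionGuard() runs on return: the caller's insertion point is intact,
  // so ops it creates afterwards land where it expects, not inside the loop.
}
```

This mirrors the diff in Utility.cpp below, where the per-user guard is replaced by a single guard covering the whole function.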

File tree

2 files changed: +40 -1 lines changed

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -1467,6 +1467,7 @@ void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices) {
 namespace mlir::triton {
 void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
                                  Value val) {
+  OpBuilder::InsertionGuard guard(builder);
   SmallVector<Operation *> opsToDelete;
   SmallVector<OpOperand *> operandsToReplace;

@@ -1487,7 +1488,6 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,

     Operation *user = use.getOwner();
     // `subview(old_op)` is replaced by a new `subview(val)`.
-    OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPoint(user);
     Value newVal;
     if (auto subview = dyn_cast<ttg::MemDescSubviewOp>(user)) {
```

test/TritonNvidiaGPU/tma_lowering.mlir

Lines changed: 39 additions & 0 deletions
```diff
@@ -120,3 +120,42 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %0, %1 : tensor<64x64xf32, #blocked>, !ttg.memdesc<64x64xf32, #shared, #smem, mutable>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+#mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}>
+#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+#shared2 = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [0, 1]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @tma_load_double_use
+  tt.func public @tma_load_double_use(%arg0: !tt.tensordesc<tensor<64x32xf32, #shared>>, %arg1: !tt.tensordesc<tensor<64x64xf32, #shared1>>) -> tensor<64x32xf32, #mma1> {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma1>
+    %c32_i32 = arith.constant 32 : i32
+    %c64_i32 = arith.constant 64 : i32
+    // CHECK: %[[A:.+]] = ttg.local_alloc : () -> !ttg.memdesc<64x32xf32
+    %0 = tt.descriptor_load %arg0[%c64_i32, %c32_i32] : !tt.tensordesc<tensor<64x32xf32, #shared>> -> tensor<64x32xf32, #blocked>
+    // CHECK: %[[B:.+]] = ttg.local_load %[[A]]
+    // CHECK: %[[C:.+]] = ttg.local_alloc %[[B]]
+    %1 = ttg.local_alloc %0 : (tensor<64x32xf32, #blocked>) -> !ttg.memdesc<64x32xf32, #shared1, #smem>
+    // CHECK: %[[D:.+]] = ttg.memdesc_trans %[[C]]
+    %2 = ttg.memdesc_trans %1 {order = array<i32: 1, 0>} : !ttg.memdesc<64x32xf32, #shared1, #smem> -> !ttg.memdesc<32x64xf32, #shared2, #smem>
+    %3 = ttg.local_alloc %0 : (tensor<64x32xf32, #blocked>) -> !ttg.memdesc<64x32xf32, #shared, #smem>
+    // CHECK: %[[E:.+]] = ttg.local_load %[[D]]
+    %4 = ttg.local_load %2 : !ttg.memdesc<32x64xf32, #shared2, #smem> -> tensor<32x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    // CHECK: %[[F:.+]] = ttg.local_load %[[A]]
+    %5 = ttg.local_load %3 : !ttg.memdesc<64x32xf32, #shared, #smem> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    // CHECK: %[[G:.+]] = tt.dot %[[E]], %[[F]]
+    %6 = tt.dot %4, %5, %cst, inputPrecision = tf32 : tensor<32x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
+    // CHECK: %[[H:.+]] = ttg.local_alloc %[[G]]
+    %7 = ttg.local_alloc %6 : (tensor<32x32xf32, #mma>) -> !ttg.memdesc<32x32xf32, #shared, #smem>
+    // CHECK: {{.*}} = ttng.warp_group_dot %[[A]], %[[H]]
+    %8 = ttng.warp_group_dot %3, %7, %cst_0 {isAsync = true} : !ttg.memdesc<64x32xf32, #shared, #smem> * !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<64x32xf32, #mma1>
+    %9:3 = ttng.warp_group_dot_wait %8, %3, %7 {pendings = 0 : i32} : tensor<64x32xf32, #mma1>, !ttg.memdesc<64x32xf32, #shared, #smem>, !ttg.memdesc<32x32xf32, #shared, #smem>
+    tt.return %9 : tensor<64x32xf32, #mma1>
+  }
+}
```
