Commit 0b9853e

[Warp Specialization] Fix iterator invalidation (#7223)
Tracking the liveUntilOp using `getNextNode` is unsafe because the next node could get replaced by the rewrite of another store. Instead, track whether the live range ends before or after the op with a flag, and resolve `getNextNode` only at the point of use.
1 parent 76b6977 · commit 0b9853e
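For context, here is a minimal standalone sketch of the hazard and of the flag-based fix. It models operations with a plain std::list of placeholder structs rather than MLIR's real Operation list, so all names and types in it are illustrative only, not the pass's actual code.

// Minimal sketch (plain C++, no MLIR dependency) of why caching the result of
// getNextNode() is unsafe across rewrites, and how a "before vs. after" flag
// defers the lookup until the rewrites are done.
#include <cassert>
#include <iterator>
#include <list>
#include <string>
#include <utility>

struct Op {
  std::string name;
};
using Block = std::list<Op>;

// Simulate another store's rewrite: the op right after `it` is erased and
// replaced by a freshly created op.
void replaceNext(Block &block, Block::iterator it) {
  auto next = std::next(it);
  next = block.erase(next);            // the old successor op is destroyed
  block.insert(next, Op{"rewritten"}); // a new op takes its place
}

int main() {
  Block block{{"load"}, {"store_a"}, {"store_b"}};
  auto storeA = std::next(block.begin());

  // Unsafe pattern: cache "the op after store_a" up front...
  auto cachedLiveUntil = std::next(storeA); // points at store_b
  // ...then another rewrite replaces that op; the cached handle now refers
  // to a destroyed element and must not be used.
  replaceNext(block, storeA);
  (void)cachedLiveUntil;

  // Safe pattern from the fix: remember (op, after-flag) and resolve the
  // successor only at the point of use, after the rewrites have happened.
  std::pair<Block::iterator, bool> liveUntil{storeA, /*after=*/true};
  auto resolved = liveUntil.second ? std::next(liveUntil.first) : liveUntil.first;
  assert(resolved->name == "rewritten"); // sees the post-rewrite successor
  return 0;
}

The key point is that the (op, after) pair stays valid across rewrites of neighboring ops, whereas a cached successor pointer does not.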

2 files changed, 66 additions and 11 deletions


lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

14 additions, 11 deletions

@@ -47,7 +47,7 @@ struct PipelinedLoad {
 
   SmallVector<Operation *, 1> allocOps;
   SmallVector<Operation *, 1> liveBeforeOps;
-  SmallVector<Operation *, 0> liveUntilOps;
+  SmallVector<std::pair<Operation *, bool>, 0> liveUntilOps;
   SmallVector<Operation *, 1> asyncUsers;
 };
 
@@ -252,8 +252,6 @@ LogicalResult PipelinedLoad::determineLiveRange(Block &container,
     // memory must be live until after this operation.
     Operation *lastShmemSink =
        findNearestCommonPostDominator(shmemTerminals, postDomInfo);
-    if (lastShmemSink)
-      lastShmemSink = lastShmemSink->getNextNode();
 
     // The memory only needs to be live until before the first register user.
     Operation *liveUntilReg = findNearestCommonDominator(regSink, domInfo);
@@ -262,14 +260,16 @@ LogicalResult PipelinedLoad::determineLiveRange(Block &container,
 
     // The memory is live until before the first register user or after the last
     // shmem terminal, whichever is later.
-    Operation *liveUntilOp;
+    std::pair<Operation *, bool> liveUntilOp{nullptr, false};
     if (lastShmemSink && liveUntilReg) {
-      liveUntilOp = liveUntilReg->isBeforeInBlock(lastShmemSink) ? lastShmemSink
-                                                                 : liveUntilReg;
+      if (liveUntilReg->isBeforeInBlock(lastShmemSink))
+        liveUntilOp = {lastShmemSink, /*after=*/true};
+      else
+        liveUntilOp = {liveUntilReg, /*after=*/false};
     } else if (liveUntilReg) {
-      liveUntilOp = liveUntilReg;
+      liveUntilOp = {liveUntilReg, /*after=*/false};
     } else {
-      liveUntilOp = lastShmemSink;
+      liveUntilOp = {lastShmemSink, /*after=*/true};
     }
     liveUntilOps.push_back(liveUntilOp);
   }
@@ -316,7 +316,7 @@ void PipelinedLoadGroup::allocateAref(scf::ForOp &loop, int numStages) {
   for (PipelinedLoad &load : loads) {
     distinctAsyncUsers.insert(load.asyncUsers.begin(), load.asyncUsers.end());
     int numLiveUntil =
-        llvm::count_if(load.liveUntilOps, [](Operation *op) { return !!op; });
+        llvm::count_if(load.liveUntilOps, [](auto p) { return !!p.first; });
     maxLiveUntil = std::max(maxLiveUntil, numLiveUntil);
   }
   int arriveCount = distinctAsyncUsers.size() + maxLiveUntil;
@@ -390,8 +390,11 @@ LogicalResult PipelinedLoadGroup::lowerLoads(WarpSchedule &schedule,
 
     SmallVector<Operation *> liveUntilOps;
     for (PipelinedLoad &load : loads) {
-      if (Operation *liveUntilOp = load.liveUntilOps[i])
-        liveUntilOps.push_back(liveUntilOp);
+      auto [liveUntilOp, after] = load.liveUntilOps[i];
+      if (liveUntilOp) {
+        liveUntilOps.push_back(after ? liveUntilOp->getNextNode()
+                                     : liveUntilOp);
+      }
     }
     if (!liveUntilOps.empty()) {
       Operation *liveUntilOp =

test/TritonGPU/load-mma-specialization.mlir

52 additions, 0 deletions

@@ -13,6 +13,9 @@
 // CHECK-DAG: [[ACC_TMEM:#.*]] = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
 #acc_tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
 
+#lhs_layout = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#lhs_tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = false>
+
 #fp4_padded_shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8, fp4Padded = true, CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [2, 1, 0]}>
 
 module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
@@ -1247,6 +1250,55 @@ tt.func @local_alloc_into_mma(
   tt.return
 }
 
+// CHECK-LABEL: @shmem_sink_iterator_invalidation
+// CHECK-SAME: [[A_DESC:%arg[0-9]+]]: !tt.tensordesc
+// CHECK-SAME: [[B_DESC:%arg[0-9]+]]: !tt.tensordesc
+tt.func @shmem_sink_iterator_invalidation(
+  %k_tiles: i32,
+  %off_m: i32,
+  %off_n: i32,
+  %a_desc: !tt.tensordesc<tensor<128x64xf16, #shared>>,
+  %b_desc: !tt.tensordesc<tensor<128x64xf16, #shared>>
+) {
+  %true = arith.constant true
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+
+  %BLOCK_K = arith.constant 64 : i32
+  %zero = arith.constant dense<0.0> : tensor<128x128xf32, #acc_layout>
+
+  %result = scf.for %k = %c0_i32 to %k_tiles step %c1_i32
+      iter_args(%acc = %zero) -> tensor<128x128xf32, #acc_layout> : i32 {
+    %off_k = arith.muli %k, %BLOCK_K : i32
+
+    // CHECK: async_tma_copy_global_to_local [[B_DESC]]
+    %b_reg = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #oper_layout>
+    // CHECK: wait_barrier [[B_EMPTY:%[0-9]+]]
+    // CHECK: async_tma_copy_global_to_local [[A_DESC]][{{.*}}] [[B_DEST:%[0-9]+]], [[B_BAR:%[0-9]+]]
+    %a_reg = tt.descriptor_load %a_desc[%off_m, %off_k] : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #oper_layout>
+
+    %a_shared = ttg.local_alloc %a_reg : (tensor<128x64xf16, #oper_layout>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+    // CHECK: wait_barrier [[B_BAR]]
+    // CHECK-NEXT: [[B:%.*]] = ttg.local_load [[B_DEST]]
+    // CHECK-NEXT: arrive_barrier [[B_EMPTY]]
+    // CHECK-NEXT: memdesc_trans
+    %a = ttg.local_load %a_shared : !ttg.memdesc<128x64xf16, #shared, #smem> -> tensor<128x64xf16, #lhs_layout>
+    %b_shared = ttg.local_alloc %b_reg : (tensor<128x64xf16, #oper_layout>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+    %b_T_shared = ttg.memdesc_trans %b_shared {order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared_trans, #smem>
+    %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %a_tmem = ttng.tmem_alloc %a : (tensor<128x64xf16, #lhs_layout>) -> !ttg.memdesc<128x64xf16, #lhs_tmem, #ttng.tensor_memory>
+    %mma_tok = ttng.tc_gen5_mma %a_tmem, %b_T_shared, %c_tmem[%c_tok], %true, %true : !ttg.memdesc<128x64xf16, #lhs_tmem, #ttng.tensor_memory>, !ttg.memdesc<64x128xf16, #shared_trans, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>
+
+    %c, %load_tok = ttng.tmem_load %c_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout>
+
+    scf.yield %c : tensor<128x128xf32, #acc_layout>
+
+  } {tt.warp_specialize, tt.num_stages = 2 : i32}
+
+  "use"(%result) : (tensor<128x128xf32, #acc_layout>) -> ()
+  tt.return
+}
+
 }
 
 // -----
