Skip to content

Commit cd57ce9

Browse files
authored
[Backend] Disable LLVM LICM on warp specialize switch loop (#6870)
This prevents LLVM from hoisting an arbitrary number of values out of the switch loop, which then become live across all partition regions. This can induce tons of spilling in warp specialized kernels.
1 parent 0559d9a commit cd57ce9

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

test/Conversion/warp_specialize_to_llvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -convert-warp-specialize-to-llvm | FileCheck %s
1+
// RUN: triton-opt %s -split-input-file -mlir-print-local-scope -allow-unregistered-dialect -convert-warp-specialize-to-llvm | FileCheck %s
22

33
module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 11 : i32} {
44

@@ -76,7 +76,7 @@ llvm.func @generate_switch_loop() attributes {allocation.offset = 32 : i32} {
7676
// CHECK: [[DEFAULT]]:
7777
// CHECK-NEXT: barrier.sync 1 ;
7878
// CHECK-NEXT: barrier.sync 1 ;
79-
// CHECK-NEXT: llvm.br [[SWITCH_LOOP]]
79+
// CHECK-NEXT: llvm.br [[SWITCH_LOOP]] {loop_annotation = #llvm.loop_annotation<licm = <disable = true>>}
8080

8181
// CHECK: [[EXIT]]:
8282
// CHECK-NEXT: llvm.return

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,23 @@ static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop,
294294
}
295295
}
296296

297+
// LLVM's LICM will be tempted to hoist code out of the switch loop generated by
298+
// the `ttg.warp_specialize` lowering. However, neither NVPTX nor `ptxas` will
299+
// rematerialize this code back into the partition regions, resulting in long
300+
// live ranges for an arbitrary number of registers.
301+
//
302+
// Due to reduced warp group registers, these live values can induce spilling
303+
// in the partition regions. Prevent this by disabling LICM on the switch loop.
304+
static void disableLICM(LLVM::BrOp latchBr) {
  // Fetch the context once and use it for every attribute built below.
  MLIRContext *ctx = latchBr.getContext();
  Builder builder(ctx);

  // LICM attribute with its `disable` field set to true; the second field is
  // left empty.
  auto licmAttr =
      LLVM::LoopLICMAttr::get(ctx, builder.getBoolAttr(true), {});

  // Loop annotation that carries only the LICM attribute; every other
  // positional field is left empty.
  auto annotation = LLVM::LoopAnnotationAttr::get(
      ctx, {}, {}, {}, {}, {}, licmAttr, {}, {}, {}, {}, {}, {}, {}, {}, {});

  // Attach the annotation to the given branch.
  latchBr.setLoopAnnotationAttr(annotation);
}
313+
297314
static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
298315
const NVIDIA::TargetInfo &targetInfo) {
299316
SmallVector<WarpSpecializeOp> wsOps;
@@ -415,7 +432,8 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
415432
/*aligned=*/false);
416433
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
417434
/*aligned=*/false);
418-
b.create<LLVM::BrOp>(switchLoop);
435+
auto latchBr = b.create<LLVM::BrOp>(switchLoop);
436+
disableLICM(latchBr);
419437

420438
// Exit state.
421439
Block *switchExit = new Block;

0 commit comments

Comments
 (0)