Skip to content

Commit 677a30c

Browse files
authored
[Backend] Emit bar.warp.sync for barriers of 1 warp (#7336)
In warp-specialized regions with only 1 warp, we can emit `bar.warp.sync` instead of a barrier with a thread count. This is slightly more efficient.
1 parent 34a2120 commit 677a30c

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

test/Conversion/warp_specialize_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 :
88
llvm.func @rewrite_barriers() attributes {allocation.offset = 32 : i32} {
99
// CHECK: barrier.sync.aligned 2, 128
1010
// CHECK: barrier.sync.aligned 3, 64
11-
// CHECK: barrier.sync.aligned 4, 32
11+
// CHECK: bar.warp.sync
1212

1313
// CHECK: bb{{[0-9]+}}:
1414
// CHECK-NEXT: barrier.sync.aligned 0, 128

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "TargetInfo.h"
22
#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
3+
#include "Utility.h"
34
#include "mlir/Analysis/TopologicalSortUtils.h"
45
#include "mlir/Conversion/Passes.h"
56
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -87,6 +88,12 @@ static void createBarrier(TritonLLVMIRRewriter &b, unsigned barIdx,
8788
std::optional<unsigned> numThreads, bool aligned) {
8889
assert(barIdx < 16 && "not enough barriers");
8990

91+
// If a partition has only 1 warp, use `bar.warp.sync`.
92+
if (numThreads && *numThreads == 32) {
93+
LLVM::NVIDIA::createSyncWarp(b.getLoc(), b);
94+
return;
95+
}
96+
9097
PTXBuilder ptxBuilder;
9198
std::string ptxString;
9299
llvm::raw_string_ostream os(ptxString);
@@ -101,6 +108,10 @@ static void createBarrier(TritonLLVMIRRewriter &b, unsigned barIdx,
101108
ptxBuilder.launch(b, b.getLoc(), void_ty(b.getContext()));
102109
}
103110

111+
static void createAllBarrier(TritonLLVMIRRewriter &b, unsigned barIdx) {
112+
createBarrier(b, barIdx, /*numThreads=*/std::nullopt, /*aligned=*/false);
113+
}
114+
104115
//===----------------------------------------------------------------------===//
105116
// elideTrivialCaptures
106117
//===----------------------------------------------------------------------===//
@@ -268,14 +279,12 @@ static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop,
268279

269280
// The shared memory is only live for the entry into the region, so put
270281
// another barrier here.
271-
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
272-
/*aligned=*/false);
282+
createAllBarrier(b, kSwitchLoopBarrierIdx);
273283

274284
// Rewrite all warp returns.
275285
partition->walk([&](WarpReturnOp op) {
276286
TritonLLVMIRRewriter b(op.getLoc(), op);
277-
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
278-
/*aligned=*/false);
287+
createAllBarrier(b, kSwitchLoopBarrierIdx);
279288
if (auto actRegs = ws.getActualRegisters()) {
280289
createRegRealloc(b, (*actRegs)[partition->getRegionNumber() + 1],
281290
lowRegs);
@@ -393,8 +402,7 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
393402
b.setInsertionPointToStart(switchLoop);
394403
if (maxnreg)
395404
createRegRealloc(b, maxnreg.getInt(), lowRegs);
396-
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
397-
/*aligned=*/false);
405+
createAllBarrier(b, kSwitchLoopBarrierIdx);
398406
Value statePtr = LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func);
399407
Value relWid = b.sub(wid, b.i32_val(defaultNumWarps));
400408

@@ -448,10 +456,8 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
448456
Block *defaultBlock = new Block;
449457
funcBlocks.insert(std::next(switchLoop->getIterator()), defaultBlock);
450458
b.setInsertionPointToStart(defaultBlock);
451-
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
452-
/*aligned=*/false);
453-
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
454-
/*aligned=*/false);
459+
createAllBarrier(b, kSwitchLoopBarrierIdx);
460+
createAllBarrier(b, kSwitchLoopBarrierIdx);
455461
auto latchBr = b.create<LLVM::BrOp>(switchLoop);
456462
disableLICM(latchBr);
457463

@@ -498,18 +504,15 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
498504

499505
// First barrier releases the waiting warpgroups. The second barrier ensures
500506
// they have read the captures before the memory is released upon entry.
501-
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
502-
/*aligned=*/false);
507+
createAllBarrier(b, kSwitchLoopBarrierIdx);
503508
if (auto actRegs = ws.getActualRegisters())
504509
createRegRealloc(b, defRegs, actRegs->front());
505-
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
506-
/*aligned=*/false);
510+
createAllBarrier(b, kSwitchLoopBarrierIdx);
507511
b.create<LLVM::BrOp>(&ws.getDefaultRegion().front());
508512

509513
ws.getDefaultRegion().walk([&, ws = ws](WarpYieldOp op) mutable {
510514
TritonLLVMIRRewriter b(op.getLoc(), op);
511-
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
512-
/*aligned=*/false);
515+
createAllBarrier(b, kSwitchLoopBarrierIdx);
513516
if (auto actRegs = ws.getActualRegisters())
514517
createRegRealloc(b, actRegs->front(), defRegs);
515518
b.replaceOpWithNewOp<LLVM::BrOp>(op, op.getOperands(), after);
@@ -532,8 +535,7 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
532535
Value cst = b.i8_val(partitionStateCounter);
533536
for (int32_t i : llvm::seq(maxNumWarps))
534537
b.store(cst, b.gep(ptrTy, i8_ty, statePtr, LLVM::GEPArg(i)));
535-
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
536-
/*aligned=*/false);
538+
createAllBarrier(b, kSwitchLoopBarrierIdx);
537539
});
538540
b.setInsertionPointToStart(switchExit);
539541
b.create<LLVM::ReturnOp>(ValueRange());

0 commit comments

Comments
 (0)