11#include " TargetInfo.h"
22#include " TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
3+ #include " Utility.h"
34#include " mlir/Analysis/TopologicalSortUtils.h"
45#include " mlir/Conversion/Passes.h"
56#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -87,6 +88,12 @@ static void createBarrier(TritonLLVMIRRewriter &b, unsigned barIdx,
8788 std::optional<unsigned > numThreads, bool aligned) {
8889 assert (barIdx < 16 && " not enough barriers" );
8990
91+ // If a partition has only 1 warp, use `bar.warp.sync`.
92+ if (numThreads && *numThreads == 32 ) {
93+ LLVM::NVIDIA::createSyncWarp (b.getLoc (), b);
94+ return ;
95+ }
96+
9097 PTXBuilder ptxBuilder;
9198 std::string ptxString;
9299 llvm::raw_string_ostream os (ptxString);
@@ -101,6 +108,10 @@ static void createBarrier(TritonLLVMIRRewriter &b, unsigned barIdx,
101108 ptxBuilder.launch (b, b.getLoc (), void_ty (b.getContext ()));
102109}
103110
111+ static void createAllBarrier (TritonLLVMIRRewriter &b, unsigned barIdx) {
112+ createBarrier (b, barIdx, /* numThreads=*/ std::nullopt , /* aligned=*/ false );
113+ }
114+
104115// ===----------------------------------------------------------------------===//
105116// elideTrivialCaptures
106117// ===----------------------------------------------------------------------===//
@@ -268,14 +279,12 @@ static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop,
268279
269280 // The shared memory is only live for the entry into the region, so put
270281 // another barrier here.
271- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
272- /* aligned=*/ false );
282+ createAllBarrier (b, kSwitchLoopBarrierIdx );
273283
274284 // Rewrite all warp returns.
275285 partition->walk ([&](WarpReturnOp op) {
276286 TritonLLVMIRRewriter b (op.getLoc (), op);
277- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
278- /* aligned=*/ false );
287+ createAllBarrier (b, kSwitchLoopBarrierIdx );
279288 if (auto actRegs = ws.getActualRegisters ()) {
280289 createRegRealloc (b, (*actRegs)[partition->getRegionNumber () + 1 ],
281290 lowRegs);
@@ -393,8 +402,7 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
393402 b.setInsertionPointToStart (switchLoop);
394403 if (maxnreg)
395404 createRegRealloc (b, maxnreg.getInt (), lowRegs);
396- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
397- /* aligned=*/ false );
405+ createAllBarrier (b, kSwitchLoopBarrierIdx );
398406 Value statePtr = LLVM::getSharedMemoryBase (b.getLoc (), b, targetInfo, func);
399407 Value relWid = b.sub (wid, b.i32_val (defaultNumWarps));
400408
@@ -448,10 +456,8 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
448456 Block *defaultBlock = new Block;
449457 funcBlocks.insert (std::next (switchLoop->getIterator ()), defaultBlock);
450458 b.setInsertionPointToStart (defaultBlock);
451- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
452- /* aligned=*/ false );
453- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
454- /* aligned=*/ false );
459+ createAllBarrier (b, kSwitchLoopBarrierIdx );
460+ createAllBarrier (b, kSwitchLoopBarrierIdx );
455461 auto latchBr = b.create <LLVM::BrOp>(switchLoop);
456462 disableLICM (latchBr);
457463
@@ -498,18 +504,15 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
498504
499505 // First barrier releases the waiting warpgroups. The second barrier ensures
500506 // they have read the captures before the memory is released upon entry.
501- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
502- /* aligned=*/ false );
507+ createAllBarrier (b, kSwitchLoopBarrierIdx );
503508 if (auto actRegs = ws.getActualRegisters ())
504509 createRegRealloc (b, defRegs, actRegs->front ());
505- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
506- /* aligned=*/ false );
510+ createAllBarrier (b, kSwitchLoopBarrierIdx );
507511 b.create <LLVM::BrOp>(&ws.getDefaultRegion ().front ());
508512
509513 ws.getDefaultRegion ().walk ([&, ws = ws](WarpYieldOp op) mutable {
510514 TritonLLVMIRRewriter b (op.getLoc (), op);
511- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
512- /* aligned=*/ false );
515+ createAllBarrier (b, kSwitchLoopBarrierIdx );
513516 if (auto actRegs = ws.getActualRegisters ())
514517 createRegRealloc (b, actRegs->front (), defRegs);
515518 b.replaceOpWithNewOp <LLVM::BrOp>(op, op.getOperands (), after);
@@ -532,8 +535,7 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
532535 Value cst = b.i8_val (partitionStateCounter);
533536 for (int32_t i : llvm::seq (maxNumWarps))
534537 b.store (cst, b.gep (ptrTy, i8_ty, statePtr, LLVM::GEPArg (i)));
535- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
536- /* aligned=*/ false );
538+ createAllBarrier (b, kSwitchLoopBarrierIdx );
537539 });
538540 b.setInsertionPointToStart (switchExit);
539541 b.create <LLVM::ReturnOp>(ValueRange ());
0 commit comments