1
1
#include " TargetInfo.h"
2
2
#include " TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
3
+ #include " Utility.h"
3
4
#include " mlir/Analysis/TopologicalSortUtils.h"
4
5
#include " mlir/Conversion/Passes.h"
5
6
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -87,6 +88,12 @@ static void createBarrier(TritonLLVMIRRewriter &b, unsigned barIdx,
87
88
std::optional<unsigned > numThreads, bool aligned) {
88
89
assert (barIdx < 16 && " not enough barriers" );
89
90
91
+ // If a partition has only 1 warp, use `bar.warp.sync`.
92
+ if (numThreads && *numThreads == 32 ) {
93
+ LLVM::NVIDIA::createSyncWarp (b.getLoc (), b);
94
+ return ;
95
+ }
96
+
90
97
PTXBuilder ptxBuilder;
91
98
std::string ptxString;
92
99
llvm::raw_string_ostream os (ptxString);
@@ -101,6 +108,10 @@ static void createBarrier(TritonLLVMIRRewriter &b, unsigned barIdx,
101
108
ptxBuilder.launch (b, b.getLoc (), void_ty (b.getContext ()));
102
109
}
103
110
111
+ static void createAllBarrier (TritonLLVMIRRewriter &b, unsigned barIdx) {
112
+ createBarrier (b, barIdx, /* numThreads=*/ std::nullopt, /* aligned=*/ false );
113
+ }
114
+
104
115
// ===----------------------------------------------------------------------===//
105
116
// elideTrivialCaptures
106
117
// ===----------------------------------------------------------------------===//
@@ -268,14 +279,12 @@ static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop,
268
279
269
280
// The shared memory is only live for the entry into the region, so put
270
281
// another barrier here.
271
- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt,
272
- /* aligned=*/ false );
282
+ createAllBarrier (b, kSwitchLoopBarrierIdx );
273
283
274
284
// Rewrite all warp returns.
275
285
partition->walk ([&](WarpReturnOp op) {
276
286
TritonLLVMIRRewriter b (op.getLoc (), op);
277
- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt,
278
- /* aligned=*/ false );
287
+ createAllBarrier (b, kSwitchLoopBarrierIdx );
279
288
if (auto actRegs = ws.getActualRegisters ()) {
280
289
createRegRealloc (b, (*actRegs)[partition->getRegionNumber () + 1 ],
281
290
lowRegs);
@@ -393,8 +402,7 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
393
402
b.setInsertionPointToStart (switchLoop);
394
403
if (maxnreg)
395
404
createRegRealloc (b, maxnreg.getInt (), lowRegs);
396
- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt,
397
- /* aligned=*/ false );
405
+ createAllBarrier (b, kSwitchLoopBarrierIdx );
398
406
Value statePtr = LLVM::getSharedMemoryBase (b.getLoc (), b, targetInfo, func);
399
407
Value relWid = b.sub (wid, b.i32_val (defaultNumWarps));
400
408
@@ -448,10 +456,8 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
448
456
Block *defaultBlock = new Block;
449
457
funcBlocks.insert (std::next (switchLoop->getIterator ()), defaultBlock);
450
458
b.setInsertionPointToStart (defaultBlock);
451
- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt,
452
- /* aligned=*/ false );
453
- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt,
454
- /* aligned=*/ false );
459
+ createAllBarrier (b, kSwitchLoopBarrierIdx );
460
+ createAllBarrier (b, kSwitchLoopBarrierIdx );
455
461
auto latchBr = b.create <LLVM::BrOp>(switchLoop);
456
462
disableLICM (latchBr);
457
463
@@ -498,18 +504,15 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
498
504
499
505
// First barrier releases the waiting warpgroups. The second barrier ensures
500
506
// they have read the captures before the memory is released upon entry.
501
- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt,
502
- /* aligned=*/ false );
507
+ createAllBarrier (b, kSwitchLoopBarrierIdx );
503
508
if (auto actRegs = ws.getActualRegisters ())
504
509
createRegRealloc (b, defRegs, actRegs->front ());
505
- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt,
506
- /* aligned=*/ false );
510
+ createAllBarrier (b, kSwitchLoopBarrierIdx );
507
511
b.create <LLVM::BrOp>(&ws.getDefaultRegion ().front ());
508
512
509
513
ws.getDefaultRegion ().walk ([&, ws = ws](WarpYieldOp op) mutable {
510
514
TritonLLVMIRRewriter b (op.getLoc (), op);
511
- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt,
512
- /* aligned=*/ false );
515
+ createAllBarrier (b, kSwitchLoopBarrierIdx );
513
516
if (auto actRegs = ws.getActualRegisters ())
514
517
createRegRealloc (b, actRegs->front (), defRegs);
515
518
b.replaceOpWithNewOp <LLVM::BrOp>(op, op.getOperands (), after);
@@ -532,8 +535,7 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
532
535
Value cst = b.i8_val (partitionStateCounter);
533
536
for (int32_t i : llvm::seq (maxNumWarps))
534
537
b.store (cst, b.gep (ptrTy, i8_ty, statePtr, LLVM::GEPArg (i)));
535
- createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt,
536
- /* aligned=*/ false );
538
+ createAllBarrier (b, kSwitchLoopBarrierIdx );
537
539
});
538
540
b.setInsertionPointToStart (switchExit);
539
541
b.create <LLVM::ReturnOp>(ValueRange ());
0 commit comments