@@ -5,11 +5,13 @@
 #include "triton/Analysis/AxisInfo.h"
 #include "triton/Conversion/TritonToTritonGPU/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "llvm/ADT/ScopeExit.h"
 
 using namespace mlir;
 using namespace triton;
 using namespace triton::gpu;
+namespace ttng = triton::nvidia_gpu;
 
 //===----------------------------------------------------------------------===//
 // relayoutWarps
@@ -182,14 +184,28 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
   // If the compiler could control that, then we could allow non-uniform
   // register distributions, mostly beneficial for single-warp warpgroups that
   // just do some arithmetic.
-  constexpr unsigned nTotalRegs = 65536; // for Blackwell SMs
+  constexpr unsigned nTotalRegs = 1 << 16; // for Blackwell SMs
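+  // The SM register file holds 64K 32-bit registers, shared by all warps
+  // resident on the SM.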
   const unsigned threadsPerWarp =
       TritonGPUDialect::getThreadsPerWarp(axisInfo.getModuleOp());
   const unsigned defaultNumWarps = lookupNumWarps(wsOp);
 
   SmallVector<int32_t> partitionNumWarps =
       llvm::to_vector(wsOp.getPartitionNumWarps());
 
+  // Some instructions are throughput-critical despite having low register
+  // usage. Make sure these partitions keep enough warps for such ops to
+  // execute quickly.
+  SmallVector<int32_t> minWarpsForPartition(partitionNumWarps.size(), 1);
+  for (auto [minWarps, region] :
+       llvm::zip(minWarpsForPartition, wsOp.getPartitionRegions())) {
+    region->walk([minWarps = &minWarps](Operation *op) {
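+      // Only consider ops whose immediate parent is a loop; ops executed
+      // once are not throughput-sensitive.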
+      if (!isa<scf::ForOp>(op->getParentOp()))
+        return;
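+      // These TMA ops use few registers, so the shrinking below would
+      // otherwise starve their partition; keep at least two warps.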
+      if (isa<ttng::AsyncTMAGatherOp, ttng::AsyncTMAScatterOp,
+              ttng::AsyncTMACopyGlobalToLocalOp>(op))
+        *minWarps = 2;
+    });
+  }
+
   bool changed;
   do {
     changed = false;
@@ -215,9 +231,9 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
     int32_t curTotalNumWarps = std::accumulate(
         partitionNumWarps.begin(), partitionNumWarps.end(), defaultNumWarps);
 
-    for (auto [numWarps, tensorRegs] :
-         llvm::zip(partitionNumWarps, maxTensorRegs)) {
-      if (numWarps == 1)
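+    // Try to shrink each partition, but never below its minimum warp count.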
+    for (auto [minWarps, numWarps, tensorRegs] :
+         llvm::zip(minWarpsForPartition, partitionNumWarps, maxTensorRegs)) {
+      if (numWarps <= minWarps)
         continue;
       // Check if reducing the number of warps will still fit the tensor. If it
       // didn't fit to begin with, it won't fit after shrinking.
@@ -233,16 +249,23 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
     }
   } while (changed);
 
-  for (auto [partition, newNumWarps, prevNumWarps, tensorRegs] :
+  SmallVector<int32_t> estRegUsage(partitionNumWarps.size());
+  for (auto [partition, newNumWarps, prevNumWarps, tensorRegs, estRegs] :
        llvm::zip(wsOp.getPartitionRegions(), partitionNumWarps,
-                 wsOp.getPartitionNumWarps(), maxTensorRegs)) {
+                 wsOp.getPartitionNumWarps(), maxTensorRegs, estRegUsage)) {
256+ // "Guess" the register usage for each partition.
257+ estRegs = tensorRegs ? 80 : 48 ;
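+    // These are heuristic values: partitions with tensor computations get a
+    // larger budget. The true usage is only known after register allocation.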
+
+    // Layouts need to be reassigned if the number of warps changed and there
+    // are tensor computations.
     if (newNumWarps == prevNumWarps || !tensorRegs)
       continue;
     // We need to reassign layouts.
     if (failed(relayoutWarps(axisInfo, partition, prevNumWarps, newNumWarps,
                              runPipeline)))
       return failure();
   }
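+  // Record the per-partition estimates so later lowering stages can use them
+  // as register-allocation hints.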
+  wsOp.setRequestedRegisters(estRegUsage);
   wsOp.setPartitionNumWarps(partitionNumWarps);
   return success();
 }