@@ -18,6 +18,20 @@ struct AllocateWarpGroups
1818 void runOnOperation () override {
1919 ModuleOp mod = getOperation ();
2020
21+ int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp (mod);
22+
23+ struct WarpGroupInfo {
24+ SmallVector<Region *> partitions;
25+ int maxRequestedRegs = 0 ;
26+ unsigned numWarps = 0 ;
27+ };
28+ struct WarpGroupPartition {
29+ int startId;
30+ Region *partition;
31+ int32_t estRegs;
32+ int numWarps;
33+ };
34+
2135 // Compute the total number of warps required at any given time.
2236 int baseNumWarps = lookupNumWarps (mod);
2337 int maxExtraWarps = 0 ;
@@ -42,6 +56,81 @@ struct AllocateWarpGroups
4256 startId += size;
4357 }
4458 op.setWarpGroupStartIds (startIds);
59+
60+ // Require that an estimate has been set and that we have even warpgroups.
61+ auto regsAttr = op.getRequestedRegisters ();
62+ if (!regsAttr || op.getTotalPartitionWarps () % 4 != 0 )
63+ return ;
64+
65+ // Group the partitions into warpgroups.
66+ SmallVector<WarpGroupPartition> orderedPartitions;
67+ for (auto [startId, partition, estRegs, numWarps] :
68+ llvm::zip (startIds, op.getPartitionRegions (), *regsAttr, arr))
69+ orderedPartitions.push_back ({startId, partition, estRegs, numWarps});
70+ llvm::sort (orderedPartitions,
71+ [&](auto lhs, auto rhs) { return lhs.startId < rhs.startId ; });
72+
73+ // Iterate over the partitions and assign them to warp groups. Determine
74+ // the maximum number of requested registers per warp group.
75+ SmallVector<WarpGroupInfo> warpGroups;
76+ for (auto [startId, partition, estRegs, numWarps] : orderedPartitions) {
77+ if (startId % 4 == 0 ) {
78+ warpGroups.push_back (WarpGroupInfo{});
79+ }
80+ warpGroups.back ().partitions .push_back (partition);
81+ // Round up the nearest multiple of 8.
82+ int estRegsCeil8 = llvm::divideCeil (estRegs, 8 ) * 8 ;
83+ warpGroups.back ().maxRequestedRegs =
84+ std::max<int >(warpGroups.back ().maxRequestedRegs , estRegsCeil8);
85+ warpGroups.back ().numWarps += numWarps;
86+ }
87+
88+ // Determine the maximum number of registers per thread. This may have
89+ // been set by the user.
90+ int maxnreg;
91+ if (auto maxnregAttr =
92+ op->getAttrOfType <IntegerAttr>(AttrMaxRegistersName)) {
93+ maxnreg = maxnregAttr.getInt ();
94+ } else {
95+ maxnreg = (1 << 16 ) / (baseNumWarps + op.getTotalPartitionWarps ()) /
96+ threadsPerWarp;
97+ maxnreg = maxnreg / 8 * 8 ;
98+ }
99+
100+ // Compute the register deficit over the partition warp groups.
101+ int registerDeficit = 0 ;
102+ for (const WarpGroupInfo &wg : warpGroups) {
103+ assert (wg.numWarps % 4 == 0 );
104+ registerDeficit +=
105+ (maxnreg - wg.maxRequestedRegs ) * wg.numWarps * threadsPerWarp;
106+ }
107+ if (registerDeficit <= 0 )
108+ return ;
109+
110+ // Determine the number of extra registers that we can distribute to the
111+ // default warp group.
112+ int leftover =
113+ ((baseNumWarps * threadsPerWarp * maxnreg) + registerDeficit) /
114+ baseNumWarps / threadsPerWarp;
115+ // Round down to the nearest multiple of 8.
116+ leftover = leftover / 8 * 8 ;
117+
118+ // Generate setmaxnreg in each partition according to its warp group.
119+ SmallVector<int32_t > maxnregsPerPartition (1 + arr.size ());
120+ for (const WarpGroupInfo &wg : warpGroups) {
121+ for (Region *region : wg.partitions ) {
122+ maxnregsPerPartition[1 + region->getRegionNumber ()] =
123+ wg.maxRequestedRegs ;
124+ }
125+ }
126+ // Set the register usage for the default warp group.
127+ maxnregsPerPartition.front () = leftover;
128+ op.setActualRegisters (maxnregsPerPartition);
129+
130+ // Set the initial max number of registers. This is needed for PTXAS to
131+ // cooperate.
132+ mod->setAttr (AttrMaxRegistersName,
133+ Builder (op.getContext ()).getI32IntegerAttr (maxnreg));
45134 });
46135
47136 Builder b (&getContext ());
0 commit comments