1- #include < iterator>
2-
31#include " ReduceScanCommon.h"
42#include " mlir/Support/LLVM.h"
53#include " triton/Analysis/Utility.h"
@@ -16,37 +14,13 @@ using ::mlir::LLVM::linearize;
1614using ::mlir::triton::gpu::getTotalElemsPerThread;
1715
1816// apply combine region to acc and cur and accumulate it into acc
19- // TODO(Lezcano) This is now duplicated with ReduceOpConversion::reduce.
20- // Deduplicate
21- static SmallVector<Value> accumulate (ConversionPatternRewriter &rewriter,
22- Region &combineOp, ValueRange acc,
23- ValueRange cur) {
24- // Allows for passing an unitialized acc and use cur as the neutral element
25- if (acc.size () == 0 ) {
26- return cur;
27- }
28- assert (cur.size () == acc.size ());
29- // Create a new copy of the reduce block, and inline it
30- Block *currentBlock = rewriter.getBlock ();
31- Region &parent = *currentBlock->getParent ();
32- rewriter.cloneRegionBefore (combineOp, &parent.front ());
33- auto &newScan = parent.front ();
34- auto returnOp = dyn_cast<triton::ScanReturnOp>(newScan.getTerminator ());
35-
36- SmallVector<Value> combineArgs (2 * acc.size ());
37- for (unsigned i = 0 ; i < acc.size (); ++i) {
38- combineArgs[i] = acc[i];
39- combineArgs[acc.size () + i] = cur[i];
40- }
41-
42- rewriter.inlineBlockBefore (&newScan, &*rewriter.getInsertionPoint (),
43- combineArgs);
44- SmallVector<Value> results;
45- llvm::transform (returnOp.getResult (), std::back_inserter (results),
46- [&](Value res) { return rewriter.getRemappedValue (res); });
47- // Delete the terminator, which is no longer used
48- rewriter.eraseOp (returnOp);
49- return results;
17+ static SmallVector<Value> accumulate (ScanLoweringHelper &helper,
18+ ConversionPatternRewriter &rewriter,
19+ ValueRange acc, ValueRange cur,
20+ Value pred = {}) {
21+ auto loc = helper.getLoc ();
22+ auto &combineOp = helper.getCombineOp ();
23+ return applyCombineOp (loc, rewriter, combineOp, acc, cur, pred);
5024}
5125
5226// Scan a contiguous elements within a thread and update `srcValues` in place.
@@ -66,8 +40,8 @@ scanThreadContiguousElements(SmallVector<SmallVector<Value>> &srcValues,
6640 unsigned accIndex = (srcIndex % stride) +
6741 ((srcIndex / stride) / scanElementsPerThreads) * stride;
6842
69- accs[accIndex] = accumulate (rewriter, helper. getCombineOp (), accs[accIndex],
70- srcValues[srcIndex]);
43+ accs[accIndex] =
44+ accumulate (helper, rewriter, accs[accIndex], srcValues[srcIndex]);
7145 srcValues[srcIndex] = accs[accIndex];
7246 }
7347}
@@ -95,11 +69,11 @@ static void warpScan(SmallVector<SmallVector<Value>> &srcValues,
9569 for (unsigned j = 0 ; j < acc.size (); ++j) {
9670 shfl[j] = targetInfo.shuffleUp (rewriter, loc, acc[j], i * threadStride);
9771 }
72+ Value mask = icmp_sge (laneIdAxis, i32_val (i));
9873 SmallVector<Value> tempAcc =
99- accumulate (rewriter, helper.getCombineOp (), shfl, acc);
100- Value mask = icmp_slt (laneIdAxis, i32_val (i));
74+ accumulate (helper, rewriter, shfl, acc, mask);
10175 for (unsigned j = 0 ; j < acc.size (); ++j) {
102- acc[j] = select (mask, acc [j], tempAcc [j]);
76+ acc[j] = select (mask, tempAcc [j], acc [j]);
10377 }
10478 }
10579 srcValues[srcIndex] = acc;
@@ -164,9 +138,9 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
164138 unsigned elementStride = helper.getAxisElementStride ();
165139 unsigned threadStride = helper.getAxisThreadStride ();
166140 unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData ();
167- Value maskFirstWarp = icmp_eq (warpId, i32_val (0 ));
168- Value maskFirstLane = icmp_eq (laneIdAxis, i32_val (0 ));
169- Value maskFirstThread = and_ (maskFirstWarp, maskFirstLane );
141+ Value maskNotFirstWarp = icmp_ne (warpId, i32_val (0 ));
142+ Value maskNotFirstLane = icmp_ne (laneIdAxis, i32_val (0 ));
143+ Value maskNotFirstThread = or_ (maskNotFirstWarp, maskNotFirstLane );
170144 struct Accumulator {
171145 SmallVector<Value> acc;
172146 SmallVector<Value> maskedAcc;
@@ -212,42 +186,43 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
212186 accumulator.maskedAcc = partialReduce;
213187 continue ;
214188 }
215- accumulator. acc = accumulate (rewriter, helper. getCombineOp (),
216- accumulator.acc , partialReduce);
217- Value mask = icmp_slt (warpId, i32_val (i + 1 ) );
189+ Value mask = icmp_sge (warpId, i32_val (i + 1 ));
190+ accumulator.acc =
191+ accumulate (helper, rewriter, accumulator. acc , partialReduce, mask );
218192 for (unsigned j = 0 ; j < helper.getNumOperands (); ++j) {
219193 accumulator.maskedAcc [j] =
220- select (mask, accumulator.maskedAcc [j], accumulator.acc [j]);
194+ select (mask, accumulator.acc [j], accumulator.maskedAcc [j]);
221195 }
222196 }
223- auto temp = accumulate (rewriter, helper.getCombineOp (),
224- accumulator.maskedAcc , srcValues[srcIndex]);
197+
198+ Value pred = axisBlockId == 0 ? maskNotFirstWarp : Value{};
199+ auto temp = accumulate (helper, rewriter, accumulator.maskedAcc ,
200+ srcValues[srcIndex], pred);
225201 if (axisBlockId == 0 ) {
226202 // For the first warp and first chunk we don't have anything to
227203 // accumulate.
228204 auto val = srcValues[srcIndex];
229205 for (unsigned i = 0 ; i < helper.getNumOperands (); ++i) {
230- temp[i] = select (maskFirstWarp, val [i], temp [i]);
206+ temp[i] = select (maskNotFirstWarp, temp [i], val [i]);
231207 }
232208 }
233209 srcValues[srcIndex] = temp;
234210 // Update the rest of the contiguous elements.
235211 SmallVector<Value> lastElement (helper.getNumOperands ());
236212 for (unsigned i = 0 ; i < helper.getNumOperands (); ++i) {
237213 auto elem = targetInfo.shuffleUp (rewriter, loc, temp[i], threadStride);
238- lastElement[i] = select (maskFirstLane, accumulator.maskedAcc [i], elem );
214+ lastElement[i] = select (maskNotFirstLane, elem, accumulator.maskedAcc [i]);
239215 }
240216 for (unsigned i = 1 ; i < scanElementsPerThreads; ++i) {
217+ pred = axisBlockId == 0 ? maskNotFirstThread : Value{};
241218 auto laneValue = srcValues[srcIndex - i * elementStride];
242- laneValue =
243- accumulate (rewriter, helper.getCombineOp (), lastElement, laneValue);
219+ laneValue = accumulate (helper, rewriter, lastElement, laneValue, pred);
244220 if (axisBlockId == 0 ) {
245221 // For the first warp and first chunk we don't have anything to
246222 // accumulate.
247223 for (unsigned j = 0 ; j < helper.getNumOperands (); ++j) {
248- laneValue[j] =
249- select (maskFirstThread,
250- srcValues[srcIndex - i * elementStride][j], laneValue[j]);
224+ laneValue[j] = select (maskNotFirstThread, laneValue[j],
225+ srcValues[srcIndex - i * elementStride][j]);
251226 }
252227 }
253228 srcValues[srcIndex - i * elementStride] = laneValue;
@@ -300,8 +275,8 @@ static void AddPartialReduceOneWarp(SmallVector<SmallVector<Value>> &srcValues,
300275 if (axisBlockId == 0 ) // First chunk and first block
301276 accumulator = srcValues[srcIndex];
302277 else
303- srcValues[srcIndex] = accumulate (rewriter, helper. getCombineOp (),
304- accumulator, srcValues[srcIndex]);
278+ srcValues[srcIndex] =
279+ accumulate (helper, rewriter, accumulator, srcValues[srcIndex]);
305280 // Update the rest of the contiguous elements.
306281 auto lastElement = srcValues[srcIndex];
307282 if (scanDim > 1 ) {
@@ -319,8 +294,7 @@ static void AddPartialReduceOneWarp(SmallVector<SmallVector<Value>> &srcValues,
319294 }
320295 for (unsigned i = 1 ; i < scanElementsPerThreads; ++i) {
321296 auto laneValue = srcValues[srcIndex - i * elementStride];
322- laneValue =
323- accumulate (rewriter, helper.getCombineOp (), lastElement, laneValue);
297+ laneValue = accumulate (helper, rewriter, lastElement, laneValue);
324298 if (axisBlockId == 0 ) {
325299 for (unsigned j = 0 ; j < helper.getNumOperands (); ++j) {
326300 // For the first warp and first chunk we don't have anything to
0 commit comments