11#include " ReduceScanCommon.h"
2- #include " mlir/Dialect/LLVMIR/NVVMDialect.h"
32#include " mlir/Support/LLVM.h"
43#include " triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
54#include " triton/Conversion/TritonGPUToLLVM/Utility.h"
6- #include " triton/Dialect/TritonGPU/Transforms/Utility.h"
7- #include < vector>
85
96using namespace mlir ;
107using namespace mlir ::triton;
@@ -80,36 +77,16 @@ struct ReduceOpConversion
8077private:
8178 const TargetInfoBase &targetInfo;
8279
83- void accumulate (ConversionPatternRewriter &rewriter, Region &combineOp,
84- SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
85- if (isFirst) {
86- acc = SmallVector<Value>(cur.begin (), cur.end ());
87- return ;
80+ void accumulate (Location loc, ConversionPatternRewriter &rewriter,
81+ Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
82+ Value pred = {}) const {
83+ auto results = applyCombineOp (loc, rewriter, combineOp, acc, cur, pred);
84+ if (acc.size () < results.size ()) {
85+ acc.resize (results.size ());
8886 }
89-
90- // Create a new copy of the reduce block, and inline it
91- Block *currentBlock = rewriter.getBlock ();
92- Region &parent = *currentBlock->getParent ();
93- rewriter.cloneRegionBefore (combineOp, &parent.front ());
94- auto &newReduce = parent.front ();
95- auto returnOp = dyn_cast<triton::ReduceReturnOp>(newReduce.getTerminator ());
96-
97- llvm::SmallVector<Value> combineArgs (2 * acc.size ());
98- for (unsigned i = 0 ; i < acc.size (); ++i) {
99- combineArgs[i] = acc[i];
100- combineArgs[acc.size () + i] = cur[i];
101- }
102-
103- rewriter.inlineBlockBefore (&newReduce, &*rewriter.getInsertionPoint (),
104- combineArgs);
105-
106- auto results = returnOp.getResult ();
10787 for (unsigned i = 0 ; i < acc.size (); ++i) {
10888 acc[i] = results[i];
10989 }
110-
111- // Delete the terminator, which is no longer used
112- rewriter.eraseOp (returnOp);
11390 }
11491
11592 SmallVector<SmallVector<Value>>
@@ -165,7 +142,7 @@ struct ReduceOpConversion
165142 SmallVector<unsigned > key = offsets[i];
166143 key[op.getAxis ()] = 0 ;
167144 bool isFirst = accs.find (key) == accs.end ();
168- accumulate (rewriter, *combineOp, accs[key], srcValues[i], isFirst );
145+ accumulate (op. getLoc (), rewriter, *combineOp, accs[key], srcValues[i]);
169146 if (isFirst)
170147 indices[key] = srcIndices[i];
171148 }
@@ -175,17 +152,29 @@ struct ReduceOpConversion
175152 // region and the accumulator values as source.
176153 void warpReduce (ConversionPatternRewriter &rewriter, Location loc,
177154 SmallVector<Value> &acc, triton::ReduceOp op,
178- unsigned numLaneToReduce, unsigned interleave) const {
155+ unsigned numLaneToReduce, unsigned interleave,
156+ Value pred = {}) const {
179157 auto success = targetInfo.warpReduce (rewriter, loc, acc, op,
180158 numLaneToReduce, interleave);
181159 if (success)
182160 return ;
161+
162+ auto mod = op->getParentOfType <ModuleOp>();
163+ unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp (mod);
164+ if (iWarpSize > numLaneToReduce) {
165+ Value threadId = getThreadId (rewriter, loc);
166+ Value warpSize = i32_val (iWarpSize);
167+ Value laneId = urem (threadId, warpSize);
168+ Value lanePred = icmp_slt (laneId, i32_val (numLaneToReduce));
169+ pred = pred ? and_ (pred, lanePred) : lanePred;
170+ }
171+
183172 for (unsigned N = numLaneToReduce / 2 ; N > 0 ; N >>= 1 ) {
184173 SmallVector<Value> shfl (acc.size ());
185174 for (unsigned i = 0 ; i < acc.size (); ++i) {
186175 shfl[i] = targetInfo.shuffleXor (rewriter, loc, acc[i], N * interleave);
187176 }
188- accumulate (rewriter, op.getCombineOp (), acc, shfl, false );
177+ accumulate (op. getLoc (), rewriter, op.getCombineOp (), acc, shfl, pred );
189178 }
190179 }
191180
@@ -344,7 +333,8 @@ struct ReduceOpConversion
344333 acc[i] = targetInfo.loadShared (rewriter, loc, readPtr, elemTy,
345334 threadIsNeeded);
346335 }
347- warpReduce (rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */ );
336+ warpReduce (rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */ ,
337+ threadIsNeeded);
348338 // only the first thread in each sizeInterWarps is writing
349339 Value writeOffset = readOffset;
350340 SmallVector<Value> writePtrs (op.getNumOperands ());
0 commit comments