Skip to content

Commit 1b0f9ea

Browse files
authored
[Backend] Fix device assert inside reduction/scan region (#4811)
Currently the reduction codegen unconditionally executes the combine region, which can create problems: because we conditionally load from shared memory, the region may read uninitialized registers. Combine regions should generally be pure, so this shouldn't be observable, but with the overflow sanitizer the frontend injects assertions into the combine region. This changes the `accumulate` function to take a predicate; if the combine region isn't speculatable, we only run it on threads where the predicate is true. In the common case, the codegen is unchanged.
1 parent a70d585 commit 1b0f9ea

File tree

4 files changed

+193
-96
lines changed

4 files changed

+193
-96
lines changed

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
#include "ReduceScanCommon.h"
2-
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
32
#include "mlir/Support/LLVM.h"
43
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
54
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
6-
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
7-
#include <vector>
85

96
using namespace mlir;
107
using namespace mlir::triton;
@@ -80,36 +77,16 @@ struct ReduceOpConversion
8077
private:
8178
const TargetInfoBase &targetInfo;
8279

83-
void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
84-
SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
85-
if (isFirst) {
86-
acc = SmallVector<Value>(cur.begin(), cur.end());
87-
return;
80+
void accumulate(Location loc, ConversionPatternRewriter &rewriter,
81+
Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
82+
Value pred = {}) const {
83+
auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred);
84+
if (acc.size() < results.size()) {
85+
acc.resize(results.size());
8886
}
89-
90-
// Create a new copy of the reduce block, and inline it
91-
Block *currentBlock = rewriter.getBlock();
92-
Region &parent = *currentBlock->getParent();
93-
rewriter.cloneRegionBefore(combineOp, &parent.front());
94-
auto &newReduce = parent.front();
95-
auto returnOp = dyn_cast<triton::ReduceReturnOp>(newReduce.getTerminator());
96-
97-
llvm::SmallVector<Value> combineArgs(2 * acc.size());
98-
for (unsigned i = 0; i < acc.size(); ++i) {
99-
combineArgs[i] = acc[i];
100-
combineArgs[acc.size() + i] = cur[i];
101-
}
102-
103-
rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
104-
combineArgs);
105-
106-
auto results = returnOp.getResult();
10787
for (unsigned i = 0; i < acc.size(); ++i) {
10888
acc[i] = results[i];
10989
}
110-
111-
// Delete the terminator, which is no longer used
112-
rewriter.eraseOp(returnOp);
11390
}
11491

11592
SmallVector<SmallVector<Value>>
@@ -165,7 +142,7 @@ struct ReduceOpConversion
165142
SmallVector<unsigned> key = offsets[i];
166143
key[op.getAxis()] = 0;
167144
bool isFirst = accs.find(key) == accs.end();
168-
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
145+
accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]);
169146
if (isFirst)
170147
indices[key] = srcIndices[i];
171148
}
@@ -175,17 +152,29 @@ struct ReduceOpConversion
175152
// region and the accumulator values as source.
176153
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
177154
SmallVector<Value> &acc, triton::ReduceOp op,
178-
unsigned numLaneToReduce, unsigned interleave) const {
155+
unsigned numLaneToReduce, unsigned interleave,
156+
Value pred = {}) const {
179157
auto success = targetInfo.warpReduce(rewriter, loc, acc, op,
180158
numLaneToReduce, interleave);
181159
if (success)
182160
return;
161+
162+
auto mod = op->getParentOfType<ModuleOp>();
163+
unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
164+
if (iWarpSize > numLaneToReduce) {
165+
Value threadId = getThreadId(rewriter, loc);
166+
Value warpSize = i32_val(iWarpSize);
167+
Value laneId = urem(threadId, warpSize);
168+
Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
169+
pred = pred ? and_(pred, lanePred) : lanePred;
170+
}
171+
183172
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
184173
SmallVector<Value> shfl(acc.size());
185174
for (unsigned i = 0; i < acc.size(); ++i) {
186175
shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave);
187176
}
188-
accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
177+
accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred);
189178
}
190179
}
191180

@@ -344,7 +333,8 @@ struct ReduceOpConversion
344333
acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy,
345334
threadIsNeeded);
346335
}
347-
warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */);
336+
warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */,
337+
threadIsNeeded);
348338
// only the first thread in each sizeInterWarps is writing
349339
Value writeOffset = readOffset;
350340
SmallVector<Value> writePtrs(op.getNumOperands());

lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
// TODO: refactor so that it doesn't fail if Allocation.h
55
// is included after utility.h (due to conflict in `store` macro
66
// and <atomic>
7-
#include "triton/Analysis/Allocation.h"
7+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
8+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
9+
#include "mlir/Transforms/DialectConversion.h"
810

9-
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
1011
//
1112
#include "mlir/IR/TypeUtilities.h"
12-
#include "triton/Analysis/AxisInfo.h"
1313
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
14-
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
15-
#include <set>
14+
#include <iterator>
1615
#include <type_traits>
1716

1817
#define DEBUG_TYPE "ttgpu_to_llvm"
@@ -32,6 +31,91 @@ namespace ttng = ::mlir::triton::nvidia_gpu;
3231
namespace mlir::triton {
3332
class ReduceOp;
3433
class ScanOp;
34+
35+
inline SmallVector<Value>
36+
inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock,
37+
Block *insertionBlock, Block::iterator insertionPoint,
38+
ValueRange combineArgs) {
39+
auto returnOp = combineBlock.getTerminator();
40+
rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint,
41+
combineArgs);
42+
43+
auto results = SmallVector<Value>(returnOp->getOperands());
44+
45+
// Delete the terminator, which is no longer used
46+
rewriter.eraseOp(returnOp);
47+
return results;
48+
}
49+
50+
inline SmallVector<Value> applyCombineOp(Location loc,
51+
ConversionPatternRewriter &rewriter,
52+
Region &combineOp, ValueRange acc,
53+
ValueRange cur, Value pred = {}) {
54+
// Allows for passing an unitialized acc and use cur as the neutral element
55+
if (acc.size() == 0) {
56+
return cur;
57+
}
58+
assert(cur.size() == acc.size());
59+
60+
// Create a new copy of the combine block, and try to speculatively inline it
61+
Block *currentBlock = rewriter.getBlock();
62+
Region &parent = *currentBlock->getParent();
63+
64+
rewriter.cloneRegionBefore(combineOp, parent,
65+
std::next(currentBlock->getIterator()));
66+
Block &newCombine = *currentBlock->getNextNode();
67+
68+
llvm::SmallVector<Value> combineArgs(2 * acc.size());
69+
for (unsigned i = 0; i < acc.size(); ++i) {
70+
combineArgs[i] = acc[i];
71+
combineArgs[acc.size() + i] = cur[i];
72+
}
73+
74+
auto isRegionSpeculatable =
75+
std::all_of(newCombine.begin(), newCombine.end(),
76+
[](auto &op) { return isSpeculatable(&op); });
77+
78+
if (!pred || isRegionSpeculatable) {
79+
// Fast path, region has no side effects so we can unconditionally execute
80+
return inlineCombineBlock(rewriter, newCombine, currentBlock,
81+
rewriter.getInsertionPoint(), combineArgs);
82+
}
83+
84+
// Slow case, create an if to only execute region when pred is true
85+
// #currentBlock
86+
// if (pred) {
87+
// #newCombine
88+
// results = combineOp(cur, acc)
89+
// yield results
90+
// } else {
91+
// yield undef
92+
// }
93+
// #thenBlock
94+
Block *thenBlock =
95+
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
96+
97+
auto returnOp = newCombine.getTerminator();
98+
auto results = SmallVector<Value>(returnOp->getOperands());
99+
100+
rewriter.setInsertionPointToEnd(currentBlock);
101+
SmallVector<Value> thenBlockArgs;
102+
thenBlockArgs.reserve(results.size());
103+
for (auto result : results) {
104+
auto ty = result.getType();
105+
auto undef = rewriter.create<LLVM::UndefOp>(loc, ty);
106+
thenBlockArgs.push_back(undef);
107+
thenBlock->addArgument(ty, loc);
108+
}
109+
rewriter.create<cf::CondBranchOp>(loc, pred, &newCombine, combineArgs,
110+
thenBlock, thenBlockArgs);
111+
112+
// Split a block after the call.
113+
rewriter.setInsertionPointToEnd(&newCombine);
114+
rewriter.replaceOpWithNewOp<cf::BranchOp>(returnOp, thenBlock, results);
115+
rewriter.setInsertionPointToStart(thenBlock);
116+
return SmallVector<Value>(thenBlock->getArguments());
117+
}
118+
35119
} // namespace mlir::triton
36120

37121
template <typename SourceOp>

lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp

Lines changed: 32 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#include <iterator>
2-
31
#include "ReduceScanCommon.h"
42
#include "mlir/Support/LLVM.h"
53
#include "triton/Analysis/Utility.h"
@@ -16,37 +14,13 @@ using ::mlir::LLVM::linearize;
1614
using ::mlir::triton::gpu::getTotalElemsPerThread;
1715

1816
// apply combine region to acc and cur and accumulate it into acc
19-
// TODO(Lezcano) This is now duplicated with ReduceOpConversion::reduce.
20-
// Deduplicate
21-
static SmallVector<Value> accumulate(ConversionPatternRewriter &rewriter,
22-
Region &combineOp, ValueRange acc,
23-
ValueRange cur) {
24-
// Allows for passing an unitialized acc and use cur as the neutral element
25-
if (acc.size() == 0) {
26-
return cur;
27-
}
28-
assert(cur.size() == acc.size());
29-
// Create a new copy of the reduce block, and inline it
30-
Block *currentBlock = rewriter.getBlock();
31-
Region &parent = *currentBlock->getParent();
32-
rewriter.cloneRegionBefore(combineOp, &parent.front());
33-
auto &newScan = parent.front();
34-
auto returnOp = dyn_cast<triton::ScanReturnOp>(newScan.getTerminator());
35-
36-
SmallVector<Value> combineArgs(2 * acc.size());
37-
for (unsigned i = 0; i < acc.size(); ++i) {
38-
combineArgs[i] = acc[i];
39-
combineArgs[acc.size() + i] = cur[i];
40-
}
41-
42-
rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(),
43-
combineArgs);
44-
SmallVector<Value> results;
45-
llvm::transform(returnOp.getResult(), std::back_inserter(results),
46-
[&](Value res) { return rewriter.getRemappedValue(res); });
47-
// Delete the terminator, which is no longer used
48-
rewriter.eraseOp(returnOp);
49-
return results;
17+
static SmallVector<Value> accumulate(ScanLoweringHelper &helper,
18+
ConversionPatternRewriter &rewriter,
19+
ValueRange acc, ValueRange cur,
20+
Value pred = {}) {
21+
auto loc = helper.getLoc();
22+
auto &combineOp = helper.getCombineOp();
23+
return applyCombineOp(loc, rewriter, combineOp, acc, cur, pred);
5024
}
5125

5226
// Scan a contiguous elements within a thread and update `srcValues` in place.
@@ -66,8 +40,8 @@ scanThreadContiguousElements(SmallVector<SmallVector<Value>> &srcValues,
6640
unsigned accIndex = (srcIndex % stride) +
6741
((srcIndex / stride) / scanElementsPerThreads) * stride;
6842

69-
accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex],
70-
srcValues[srcIndex]);
43+
accs[accIndex] =
44+
accumulate(helper, rewriter, accs[accIndex], srcValues[srcIndex]);
7145
srcValues[srcIndex] = accs[accIndex];
7246
}
7347
}
@@ -95,11 +69,11 @@ static void warpScan(SmallVector<SmallVector<Value>> &srcValues,
9569
for (unsigned j = 0; j < acc.size(); ++j) {
9670
shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride);
9771
}
72+
Value mask = icmp_sge(laneIdAxis, i32_val(i));
9873
SmallVector<Value> tempAcc =
99-
accumulate(rewriter, helper.getCombineOp(), shfl, acc);
100-
Value mask = icmp_slt(laneIdAxis, i32_val(i));
74+
accumulate(helper, rewriter, shfl, acc, mask);
10175
for (unsigned j = 0; j < acc.size(); ++j) {
102-
acc[j] = select(mask, acc[j], tempAcc[j]);
76+
acc[j] = select(mask, tempAcc[j], acc[j]);
10377
}
10478
}
10579
srcValues[srcIndex] = acc;
@@ -164,9 +138,9 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
164138
unsigned elementStride = helper.getAxisElementStride();
165139
unsigned threadStride = helper.getAxisThreadStride();
166140
unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
167-
Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
168-
Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0));
169-
Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);
141+
Value maskNotFirstWarp = icmp_ne(warpId, i32_val(0));
142+
Value maskNotFirstLane = icmp_ne(laneIdAxis, i32_val(0));
143+
Value maskNotFirstThread = or_(maskNotFirstWarp, maskNotFirstLane);
170144
struct Accumulator {
171145
SmallVector<Value> acc;
172146
SmallVector<Value> maskedAcc;
@@ -212,42 +186,43 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
212186
accumulator.maskedAcc = partialReduce;
213187
continue;
214188
}
215-
accumulator.acc = accumulate(rewriter, helper.getCombineOp(),
216-
accumulator.acc, partialReduce);
217-
Value mask = icmp_slt(warpId, i32_val(i + 1));
189+
Value mask = icmp_sge(warpId, i32_val(i + 1));
190+
accumulator.acc =
191+
accumulate(helper, rewriter, accumulator.acc, partialReduce, mask);
218192
for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
219193
accumulator.maskedAcc[j] =
220-
select(mask, accumulator.maskedAcc[j], accumulator.acc[j]);
194+
select(mask, accumulator.acc[j], accumulator.maskedAcc[j]);
221195
}
222196
}
223-
auto temp = accumulate(rewriter, helper.getCombineOp(),
224-
accumulator.maskedAcc, srcValues[srcIndex]);
197+
198+
Value pred = axisBlockId == 0 ? maskNotFirstWarp : Value{};
199+
auto temp = accumulate(helper, rewriter, accumulator.maskedAcc,
200+
srcValues[srcIndex], pred);
225201
if (axisBlockId == 0) {
226202
// For the first warp and first chunk we don't have anything to
227203
// accumulate.
228204
auto val = srcValues[srcIndex];
229205
for (unsigned i = 0; i < helper.getNumOperands(); ++i) {
230-
temp[i] = select(maskFirstWarp, val[i], temp[i]);
206+
temp[i] = select(maskNotFirstWarp, temp[i], val[i]);
231207
}
232208
}
233209
srcValues[srcIndex] = temp;
234210
// Update the rest of the contiguous elements.
235211
SmallVector<Value> lastElement(helper.getNumOperands());
236212
for (unsigned i = 0; i < helper.getNumOperands(); ++i) {
237213
auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride);
238-
lastElement[i] = select(maskFirstLane, accumulator.maskedAcc[i], elem);
214+
lastElement[i] = select(maskNotFirstLane, elem, accumulator.maskedAcc[i]);
239215
}
240216
for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
217+
pred = axisBlockId == 0 ? maskNotFirstThread : Value{};
241218
auto laneValue = srcValues[srcIndex - i * elementStride];
242-
laneValue =
243-
accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue);
219+
laneValue = accumulate(helper, rewriter, lastElement, laneValue, pred);
244220
if (axisBlockId == 0) {
245221
// For the first warp and first chunk we don't have anything to
246222
// accumulate.
247223
for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
248-
laneValue[j] =
249-
select(maskFirstThread,
250-
srcValues[srcIndex - i * elementStride][j], laneValue[j]);
224+
laneValue[j] = select(maskNotFirstThread, laneValue[j],
225+
srcValues[srcIndex - i * elementStride][j]);
251226
}
252227
}
253228
srcValues[srcIndex - i * elementStride] = laneValue;
@@ -300,8 +275,8 @@ static void AddPartialReduceOneWarp(SmallVector<SmallVector<Value>> &srcValues,
300275
if (axisBlockId == 0) // First chunk and first block
301276
accumulator = srcValues[srcIndex];
302277
else
303-
srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(),
304-
accumulator, srcValues[srcIndex]);
278+
srcValues[srcIndex] =
279+
accumulate(helper, rewriter, accumulator, srcValues[srcIndex]);
305280
// Update the rest of the contiguous elements.
306281
auto lastElement = srcValues[srcIndex];
307282
if (scanDim > 1) {
@@ -319,8 +294,7 @@ static void AddPartialReduceOneWarp(SmallVector<SmallVector<Value>> &srcValues,
319294
}
320295
for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
321296
auto laneValue = srcValues[srcIndex - i * elementStride];
322-
laneValue =
323-
accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue);
297+
laneValue = accumulate(helper, rewriter, lastElement, laneValue);
324298
if (axisBlockId == 0) {
325299
for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
326300
// For the first warp and first chunk we don't have anything to

0 commit comments

Comments
 (0)