Skip to content

Commit 2a4b054

Browse files
Merge commit 'e7ec3fee2983a447a31812da2e56acdc0d7ee6af'
2 parents 2a062d3 + e7ec3fe commit 2a4b054

File tree

17 files changed

+473
-200
lines changed

17 files changed

+473
-200
lines changed

.github/workflows/documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525

2626
- name: Install dependent packages
2727
run: |
28-
sudo pip3 install tabulate cmake sphinx matplotlib myst_parser sphinx-rtd-theme pandas pytest sphinx-gallery sphinx-multiversion
28+
sudo pip3 install tabulate cmake sphinx matplotlib myst_parser sphinx-rtd-theme pandas pytest sphinx-gallery sphinx-multiversion llnl-hatchet
2929
3030
#- name: Fetch dependent branches
3131
# run: |

docs/conf.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,6 @@ def documenter(app, obj, parent):
159159
'examples_dirs': '../python/tutorials/',
160160
'gallery_dirs': 'getting-started/tutorials',
161161
'filename_pattern': '',
162-
# TODO: Re-enable the grouped-gemm tutorial. It currently hits this
163-
# assertion:
164-
# https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp#L127
165162
'ignore_pattern': r'(__init__\.py|11.*.py)',
166163
'within_subsection_order': FileNameSortKey,
167164
'reference_url': {

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
#include "ReduceScanCommon.h"
2-
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
32
#include "mlir/Support/LLVM.h"
43
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
54
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
6-
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
7-
#include <vector>
85

96
using namespace mlir;
107
using namespace mlir::triton;
@@ -80,36 +77,16 @@ struct ReduceOpConversion
8077
private:
8178
const TargetInfoBase &targetInfo;
8279

83-
void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
84-
SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
85-
if (isFirst) {
86-
acc = SmallVector<Value>(cur.begin(), cur.end());
87-
return;
80+
void accumulate(Location loc, ConversionPatternRewriter &rewriter,
81+
Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
82+
Value pred = {}) const {
83+
auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred);
84+
if (acc.size() < results.size()) {
85+
acc.resize(results.size());
8886
}
89-
90-
// Create a new copy of the reduce block, and inline it
91-
Block *currentBlock = rewriter.getBlock();
92-
Region &parent = *currentBlock->getParent();
93-
rewriter.cloneRegionBefore(combineOp, &parent.front());
94-
auto &newReduce = parent.front();
95-
auto returnOp = dyn_cast<triton::ReduceReturnOp>(newReduce.getTerminator());
96-
97-
llvm::SmallVector<Value> combineArgs(2 * acc.size());
98-
for (unsigned i = 0; i < acc.size(); ++i) {
99-
combineArgs[i] = acc[i];
100-
combineArgs[acc.size() + i] = cur[i];
101-
}
102-
103-
rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
104-
combineArgs);
105-
106-
auto results = returnOp.getResult();
10787
for (unsigned i = 0; i < acc.size(); ++i) {
10888
acc[i] = results[i];
10989
}
110-
111-
// Delete the terminator, which is no longer used
112-
rewriter.eraseOp(returnOp);
11390
}
11491

11592
SmallVector<SmallVector<Value>>
@@ -165,7 +142,7 @@ struct ReduceOpConversion
165142
SmallVector<unsigned> key = offsets[i];
166143
key[op.getAxis()] = 0;
167144
bool isFirst = accs.find(key) == accs.end();
168-
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
145+
accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]);
169146
if (isFirst)
170147
indices[key] = srcIndices[i];
171148
}
@@ -175,17 +152,29 @@ struct ReduceOpConversion
175152
// region and the accumulator values as source.
176153
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
177154
SmallVector<Value> &acc, triton::ReduceOp op,
178-
unsigned numLaneToReduce, unsigned interleave) const {
155+
unsigned numLaneToReduce, unsigned interleave,
156+
Value pred = {}) const {
179157
auto success = targetInfo.warpReduce(rewriter, loc, acc, op,
180158
numLaneToReduce, interleave);
181159
if (success)
182160
return;
161+
162+
auto mod = op->getParentOfType<ModuleOp>();
163+
unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
164+
if (iWarpSize > numLaneToReduce) {
165+
Value threadId = getThreadId(rewriter, loc);
166+
Value warpSize = i32_val(iWarpSize);
167+
Value laneId = urem(threadId, warpSize);
168+
Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
169+
pred = pred ? and_(pred, lanePred) : lanePred;
170+
}
171+
183172
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
184173
SmallVector<Value> shfl(acc.size());
185174
for (unsigned i = 0; i < acc.size(); ++i) {
186175
shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave);
187176
}
188-
accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
177+
accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred);
189178
}
190179
}
191180

@@ -344,7 +333,8 @@ struct ReduceOpConversion
344333
acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy,
345334
threadIsNeeded);
346335
}
347-
warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */);
336+
warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */,
337+
threadIsNeeded);
348338
// only the first thread in each sizeInterWarps is writing
349339
Value writeOffset = readOffset;
350340
SmallVector<Value> writePtrs(op.getNumOperands());

lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
// TODO: refactor so that it doesn't fail if Allocation.h
55
// is included after utility.h (due to conflict in `store` macro
66
// and <atomic>)
7-
#include "triton/Analysis/Allocation.h"
7+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
8+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
9+
#include "mlir/Transforms/DialectConversion.h"
810

9-
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
1011
//
1112
#include "mlir/IR/TypeUtilities.h"
12-
#include "triton/Analysis/AxisInfo.h"
1313
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
14-
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
15-
#include <set>
14+
#include <iterator>
1615
#include <type_traits>
1716

1817
#define DEBUG_TYPE "ttgpu_to_llvm"
@@ -32,6 +31,91 @@ namespace ttng = ::mlir::triton::nvidia_gpu;
3231
namespace mlir::triton {
3332
class ReduceOp;
3433
class ScanOp;
34+
35+
inline SmallVector<Value>
36+
inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock,
37+
Block *insertionBlock, Block::iterator insertionPoint,
38+
ValueRange combineArgs) {
39+
auto returnOp = combineBlock.getTerminator();
40+
rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint,
41+
combineArgs);
42+
43+
auto results = SmallVector<Value>(returnOp->getOperands());
44+
45+
// Delete the terminator, which is no longer used
46+
rewriter.eraseOp(returnOp);
47+
return results;
48+
}
49+
50+
inline SmallVector<Value> applyCombineOp(Location loc,
51+
ConversionPatternRewriter &rewriter,
52+
Region &combineOp, ValueRange acc,
53+
ValueRange cur, Value pred = {}) {
54+
// Allows for passing an uninitialized acc and using cur as the neutral element
55+
if (acc.size() == 0) {
56+
return cur;
57+
}
58+
assert(cur.size() == acc.size());
59+
60+
// Create a new copy of the combine block, and try to speculatively inline it
61+
Block *currentBlock = rewriter.getBlock();
62+
Region &parent = *currentBlock->getParent();
63+
64+
rewriter.cloneRegionBefore(combineOp, parent,
65+
std::next(currentBlock->getIterator()));
66+
Block &newCombine = *currentBlock->getNextNode();
67+
68+
llvm::SmallVector<Value> combineArgs(2 * acc.size());
69+
for (unsigned i = 0; i < acc.size(); ++i) {
70+
combineArgs[i] = acc[i];
71+
combineArgs[acc.size() + i] = cur[i];
72+
}
73+
74+
auto isRegionSpeculatable =
75+
std::all_of(newCombine.begin(), newCombine.end(),
76+
[](auto &op) { return isSpeculatable(&op); });
77+
78+
if (!pred || isRegionSpeculatable) {
79+
// Fast path: either no predicate was given or the region is speculatable, so we can execute it unconditionally
80+
return inlineCombineBlock(rewriter, newCombine, currentBlock,
81+
rewriter.getInsertionPoint(), combineArgs);
82+
}
83+
84+
// Slow case, create an if to only execute region when pred is true
85+
// #currentBlock
86+
// if (pred) {
87+
// #newCombine
88+
// results = combineOp(acc, cur)
89+
// yield results
90+
// } else {
91+
// yield undef
92+
// }
93+
// #thenBlock
94+
Block *thenBlock =
95+
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
96+
97+
auto returnOp = newCombine.getTerminator();
98+
auto results = SmallVector<Value>(returnOp->getOperands());
99+
100+
rewriter.setInsertionPointToEnd(currentBlock);
101+
SmallVector<Value> thenBlockArgs;
102+
thenBlockArgs.reserve(results.size());
103+
for (auto result : results) {
104+
auto ty = result.getType();
105+
auto undef = rewriter.create<LLVM::UndefOp>(loc, ty);
106+
thenBlockArgs.push_back(undef);
107+
thenBlock->addArgument(ty, loc);
108+
}
109+
rewriter.create<cf::CondBranchOp>(loc, pred, &newCombine, combineArgs,
110+
thenBlock, thenBlockArgs);
111+
112+
// Replace the combine block's return with a branch to thenBlock that forwards the results.
113+
rewriter.setInsertionPointToEnd(&newCombine);
114+
rewriter.replaceOpWithNewOp<cf::BranchOp>(returnOp, thenBlock, results);
115+
rewriter.setInsertionPointToStart(thenBlock);
116+
return SmallVector<Value>(thenBlock->getArguments());
117+
}
118+
35119
} // namespace mlir::triton
36120

37121
template <typename SourceOp>

0 commit comments

Comments
 (0)