
Commit d651a84

More code refactoring and sync from upstream (#2409)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent cbecf26 commit d651a84

6 files changed: 186 additions, 150 deletions

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 2 additions & 0 deletions
@@ -255,6 +255,8 @@ class TritonGPUToLLVMPipelineManager {
 
     intel::populateSPMDOpToLLVMPattern(typeConverter, patterns, targetInfo,
                                        benefit);
+    mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns,
+                                              targetInfo, benefit);
     // TODO(thomas): this should probably be done in a separate step to not
     // interfere with our own lowering of arith ops. Add arith/math's patterns
     // to help convert scalar expression to LLVM.

third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 21 additions & 31 deletions
@@ -10,6 +10,7 @@ using namespace mlir::triton;
 using ::mlir::LLVM::delinearize;
 using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::getOrder;
+using ::mlir::triton::gpu::getThreadOrder;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 
 namespace {
@@ -80,35 +81,14 @@ struct ReduceOpConversion
 
   void accumulate(Location loc, ConversionPatternRewriter &rewriter,
                   Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
-                  bool isFirst) const {
-    if (isFirst) {
-      acc = SmallVector<Value>(cur.begin(), cur.end());
-      return;
-    }
-
-    // Create a new copy of the reduce block, and inline it
-    Block *currentBlock = rewriter.getBlock();
-    Region &parent = *currentBlock->getParent();
-    rewriter.cloneRegionBefore(combineOp, &parent.front());
-    auto &newReduce = parent.front();
-    auto returnOp = dyn_cast<triton::ReduceReturnOp>(newReduce.getTerminator());
-
-    llvm::SmallVector<Value> combineArgs(2 * acc.size());
-    for (unsigned i = 0; i < acc.size(); ++i) {
-      combineArgs[i] = acc[i];
-      combineArgs[acc.size() + i] = cur[i];
+                  Value pred = {}) const {
+    auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred);
+    if (acc.size() < results.size()) {
+      acc.resize(results.size());
     }
-
-    rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
-                               combineArgs);
-
-    auto results = returnOp.getResult();
     for (unsigned i = 0; i < acc.size(); ++i) {
       acc[i] = results[i];
     }
-
-    // Delete the terminator, which is no longer used
-    rewriter.eraseOp(returnOp);
   }
 
   SmallVector<SmallVector<Value>>
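
Note: the rewritten accumulate() now delegates to the shared applyCombineOp() helper added in ReduceScanCommon.h below. Instead of an explicit isFirst flag, an empty accumulator is treated as the identity and is simply seeded by the first value, and an optional predicate can be threaded through to guard non-speculatable combine regions. A minimal standalone sketch of that calling convention (plain C++, not the MLIR lowering; the function name, the int element type, and the combine function pointer are illustrative only):

#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of the accumulate()/applyCombineOp() contract: an empty accumulator
// is seeded by the first value; later calls combine element-wise.
static std::vector<int> applyCombineModel(const std::vector<int> &acc,
                                          const std::vector<int> &cur,
                                          int (*combine)(int, int)) {
  if (acc.empty())
    return cur; // neutral element: the first value seeds the accumulator
  assert(acc.size() == cur.size());
  std::vector<int> out(acc.size());
  for (std::size_t i = 0; i < acc.size(); ++i)
    out[i] = combine(acc[i], cur[i]);
  return out;
}

With this contract the call site below no longer passes isFirst, and the resize in accumulate() only takes effect on the first call, when the accumulator grows from empty to the operand count.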
@@ -165,8 +145,7 @@ struct ReduceOpConversion
       SmallVector<unsigned> key = offsets[i];
       key[op.getAxis()] = 0;
       bool isFirst = accs.find(key) == accs.end();
-      accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i],
-                 isFirst);
+      accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]);
       if (isFirst)
         indices[key] = srcIndices[i];
     }
@@ -182,12 +161,23 @@ struct ReduceOpConversion
                             numLaneToReduce, interleave);
     if (success)
       return;
+
+    auto mod = op->getParentOfType<ModuleOp>();
+    unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+    if (iWarpSize > numLaneToReduce) {
+      Value threadId = getThreadId(rewriter, loc);
+      Value warpSize = i32_val(iWarpSize);
+      Value laneId = urem(threadId, warpSize);
+      Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
+      pred = pred ? and_(pred, lanePred) : lanePred;
+    }
+
     for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
       SmallVector<Value> shfl(acc.size());
       for (unsigned i = 0; i < acc.size(); ++i) {
         shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave);
       }
-      accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, false);
+      accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred);
     }
   }
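
Note: the new lane predicate matters when the hardware warp (iWarpSize) is wider than the number of lanes that actually carry data (numLaneToReduce); those extra lanes must not execute a combine region with side effects, so the predicate is passed to accumulate() and, from there, to applyCombineOp(). A standalone model of the predicated XOR-shuffle (butterfly) reduction, assuming power-of-two sizes as the loop above does (plain C++; the name is illustrative and integer addition stands in for the combine region):

#include <vector>

// Butterfly reduction over warpSize lanes; only the first numLaneToReduce
// lanes hold valid data. Assumes lanes.size() and numLaneToReduce are powers
// of two. The guard mirrors lanePred = laneId < numLaneToReduce.
static int predicatedWarpReduceModel(std::vector<int> lanes,
                                     unsigned numLaneToReduce) {
  const unsigned warpSize = static_cast<unsigned>(lanes.size());
  for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
    std::vector<int> shfl(warpSize);
    for (unsigned lane = 0; lane < warpSize; ++lane)
      shfl[lane] = lanes[lane ^ N]; // shuffleXor partner
    for (unsigned lane = 0; lane < warpSize; ++lane)
      if (lane < numLaneToReduce)   // lanePred: combine only on valid lanes
        lanes[lane] += shfl[lane];  // '+' stands in for the combine region
  }
  return lanes[0]; // every valid lane, in particular lane 0, holds the result
}

The predicate is not needed for the arithmetic itself (lanes at or beyond numLaneToReduce would merely compute values nobody reads); in the real lowering it keeps a non-speculatable combine region from running on those lanes.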

@@ -242,7 +232,7 @@ struct ReduceOpConversion
                             ConversionPatternRewriter &rewriter) const {
     auto srcLayout = helper.getSrcLayout();
     auto srcShape = helper.getSrcShape();
-    auto order = getOrder(srcLayout);
+    auto order = triton::gpu::getWarpOrder(srcLayout);
     SmallVector<Value> multiDimWarpId;
 
     // 2x2 warps with slice dim = 0, warpId = 2 ends up writing at the same
@@ -251,7 +241,7 @@ struct ReduceOpConversion
     if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
       auto parentLayout = sliceLayout.getParent();
       auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout);
-      auto parentOrder = triton::gpu::getOrder(parentLayout);
+      auto parentOrder = triton::gpu::getWarpOrder(parentLayout);
       multiDimWarpId =
           delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder);
       multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim());
@@ -284,7 +274,7 @@ struct ReduceOpConversion
 
     auto threadsPerWarp =
         triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
         delinearize(rewriter, loc, laneId, threadsPerWarp, order);
     Value laneIdAxis = multiDimLaneId[axis];
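
Note: the last three hunks switch warp-id delinearization to getWarpOrder() and lane-id delinearization to getThreadOrder() instead of the generic getOrder(); for layouts whose warp and thread dimension orders differ, using the wrong order yields wrong multi-dimensional coordinates. A small standalone model of what delinearize() does with an order vector (plain C++, 2-D only; the name is illustrative):

#include <array>

// Map a linear id to 2-D coordinates, consuming dimensions in the given
// order (fastest-varying dimension first). Different orders give different
// coordinates for the same linear id, which is why warp ids and lane ids
// must each be delinearized with their own order.
static std::array<unsigned, 2> delinearizeModel(unsigned linearId,
                                                std::array<unsigned, 2> shape,
                                                std::array<unsigned, 2> order) {
  std::array<unsigned, 2> coords{};
  for (unsigned dim : order) {
    coords[dim] = linearId % shape[dim];
    linearId /= shape[dim];
  }
  return coords;
}

For example, with shape {2, 4} a linear id of 5 maps to coordinates {1, 2} under order {0, 1} but to {1, 1} under order {1, 0}.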

third_party/intel/lib/TritonIntelGPUToLLVM/ReduceScanCommon.h

Lines changed: 94 additions & 11 deletions
@@ -4,15 +4,13 @@
 // TODO: refactor so that it doesn't fail if Allocation.h
 // is included after utility.h (due to conflict in `store` macro
 // and <atomic>
-#include "triton/Analysis/Allocation.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 //
 #include "Utility.h"
 #include "mlir/IR/TypeUtilities.h"
-
-#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
-#include "triton/Analysis/AxisInfo.h"
-#include <set>
 #include <type_traits>
 
 #define DEBUG_TYPE "ttgpu_to_llvm"
@@ -26,11 +24,95 @@ using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::CTALayoutAttr;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
-using ::mlir::triton::gpu::intel::DpasEncodingAttr;
 
 namespace mlir::triton {
 class ReduceOp;
 class ScanOp;
+
+inline SmallVector<Value>
+inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock,
+                   Block *insertionBlock, Block::iterator insertionPoint,
+                   ValueRange combineArgs) {
+  auto returnOp = combineBlock.getTerminator();
+  rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint,
+                             combineArgs);
+
+  auto results = SmallVector<Value>(returnOp->getOperands());
+
+  // Delete the terminator, which is no longer used
+  rewriter.eraseOp(returnOp);
+  return results;
+}
+
+inline SmallVector<Value> applyCombineOp(Location loc,
+                                         ConversionPatternRewriter &rewriter,
+                                         Region &combineOp, ValueRange acc,
+                                         ValueRange cur, Value pred = {}) {
+  // Allows for passing an uninitialized acc and using cur as the neutral element
+  if (acc.size() == 0) {
+    return cur;
+  }
+  assert(cur.size() == acc.size());
+
+  // Create a new copy of the combine block, and try to speculatively inline it
+  Block *currentBlock = rewriter.getBlock();
+  Region &parent = *currentBlock->getParent();
+
+  rewriter.cloneRegionBefore(combineOp, parent,
+                             std::next(currentBlock->getIterator()));
+  Block &newCombine = *currentBlock->getNextNode();
+
+  llvm::SmallVector<Value> combineArgs(2 * acc.size());
+  for (unsigned i = 0; i < acc.size(); ++i) {
+    combineArgs[i] = acc[i];
+    combineArgs[acc.size() + i] = cur[i];
+  }
+
+  auto isRegionSpeculatable =
+      std::all_of(newCombine.begin(), newCombine.end(),
+                  [](auto &op) { return isSpeculatable(&op); });
+
+  if (!pred || isRegionSpeculatable) {
+    // Fast path, region has no side effects so we can unconditionally execute
+    return inlineCombineBlock(rewriter, newCombine, currentBlock,
+                              rewriter.getInsertionPoint(), combineArgs);
+  }
+
+  // Slow case, create an if to only execute region when pred is true
+  // #currentBlock
+  // if (pred) {
+  //   #newCombine
+  //   results = combineOp(cur, acc)
+  //   yield results
+  // } else {
+  //   yield undef
+  // }
+  // #thenBlock
+  Block *thenBlock =
+      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
+
+  auto returnOp = newCombine.getTerminator();
+  auto results = SmallVector<Value>(returnOp->getOperands());
+
+  rewriter.setInsertionPointToEnd(currentBlock);
+  SmallVector<Value> thenBlockArgs;
+  thenBlockArgs.reserve(results.size());
+  for (auto result : results) {
+    auto ty = result.getType();
+    auto undef = rewriter.create<LLVM::UndefOp>(loc, ty);
+    thenBlockArgs.push_back(undef);
+    thenBlock->addArgument(ty, loc);
+  }
+  rewriter.create<cf::CondBranchOp>(loc, pred, &newCombine, combineArgs,
+                                    thenBlock, thenBlockArgs);
+
+  // Split a block after the call.
+  rewriter.setInsertionPointToEnd(&newCombine);
+  rewriter.replaceOpWithNewOp<cf::BranchOp>(returnOp, thenBlock, results);
+  rewriter.setInsertionPointToStart(thenBlock);
+  return SmallVector<Value>(thenBlock->getArguments());
+}
+
 } // namespace mlir::triton
 
 template <typename SourceOp>
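
Note: applyCombineOp() first tries the fast path: when there is no predicate, or every op in the cloned combine block is speculatable, the block is inlined unconditionally via inlineCombineBlock(). Otherwise it builds the conditional CFG sketched in the comment above, branching around the combine and feeding undef into the merge block for predicated-off lanes. A standalone model of that control flow (plain C++, not MLIR; the function name is illustrative and combineWithSideEffects is a hypothetical stand-in for a non-speculatable combine region):

// Model of the slow path: run the combine only when pred is true; otherwise
// forward a placeholder, playing the role of the LLVM::UndefOp fed to the
// merge block's argument.
static int applyCombineSlowPathModel(bool pred, int acc, int cur,
                                     int (*combineWithSideEffects)(int, int)) {
  int merged;                                  // thenBlock's block argument
  if (pred)
    merged = combineWithSideEffects(acc, cur); // #newCombine, then br thenBlock
  else
    merged = 0;                                // undef when pred is false
  return merged;                               // value visible in #thenBlock
}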
@@ -66,12 +148,13 @@ class ConvertTritonIntelGPUReduceScanToLLVMPattern
     });
     // Assign base index to each operand in their order in indices
     std::map<unsigned, Value> indexToBase;
-    indexToBase[indices[0]] = LLVM::intel::getSharedMemoryBase(
-        loc, rewriter, targetInfo, op.getOperation());
+    auto basePtr = LLVM::intel::getSharedMemoryBase(loc, rewriter, targetInfo,
+                                                    op.getOperation());
+    indexToBase[indices[0]] = basePtr;
     for (unsigned i = 1; i < op.getNumOperands(); ++i) {
-      indexToBase[indices[i]] = gep(
-          ptr_ty(rewriter.getContext(), 3), getElementType(op, indices[i - 1]),
-          indexToBase[indices[i - 1]], i32_val(elems));
+      indexToBase[indices[i]] =
+          gep(basePtr.getType(), getElementType(op, indices[i - 1]),
+              indexToBase[indices[i - 1]], i32_val(elems));
     }
     // smemBases[k] is the base pointer for the k-th operand
     SmallVector<Value> smemBases(op.getNumOperands());
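
Note: the loop above packs the operands back to back in shared memory; operand k's base is operand k-1's base advanced by elems elements of operand k-1's element type (one GEP per operand), and the change merely reuses basePtr's type instead of respelling the address-space-3 pointer type. A byte-offset model of that layout (plain C++; the name assignSmemBases and the elemSize vector are illustrative, playing the role of getElementType):

#include <cstddef>
#include <vector>

// Compute each operand's shared-memory base, packing operands back to back:
// base[k] = base[k-1] + elems * sizeof(element of operand k-1).
static std::vector<std::size_t>
assignSmemBases(std::size_t base, std::size_t elems,
                const std::vector<std::size_t> &elemSize) {
  std::vector<std::size_t> bases(elemSize.size());
  for (std::size_t k = 0; k < elemSize.size(); ++k)
    bases[k] = (k == 0) ? base : bases[k - 1] + elems * elemSize[k - 1];
  return bases;
}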

third_party/intel/lib/TritonIntelGPUToLLVM/SPMDOpToLLVM.cpp

Lines changed: 3 additions & 32 deletions
@@ -6,31 +6,6 @@ namespace {
 using namespace mlir;
 using namespace mlir::triton;
 
-struct GetProgramIdOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp> {
-  explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter,
-                                    const TargetInfoBase &targetInfo,
-                                    PatternBenefit benefit = 1)
-      : ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp>(typeConverter,
-                                                                benefit),
-        targetInfo(targetInfo) {}
-  using ConvertTritonGPUOpToLLVMPattern<
-      triton::GetProgramIdOp>::ConvertTritonGPUOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value programId = targetInfo.programId(rewriter, op->getLoc(),
-                                           op->getParentOfType<ModuleOp>(),
-                                           op.getAxisAsInt());
-    rewriter.replaceOp(op, programId);
-    return success();
-  }
-
-private:
-  const TargetInfoBase &targetInfo;
-};
-
 struct GetNumProgramsOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::GetNumProgramsOp> {
   using ConvertTritonGPUOpToLLVMPattern<
@@ -39,26 +14,22 @@ struct GetNumProgramsOpConversion
   LogicalResult
   matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
+                                                    mlir::gpu::Dimension::y,
+                                                    mlir::gpu::Dimension::z};
     Location loc = op->getLoc();
     assert(op.getAxisAsInt() < 3);
-
     Value blockId =
         rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxisAsInt()]);
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, i32_ty, blockId);
-
     return success();
   }
-
-  static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x,
-                                                  mlir::gpu::Dimension::y,
-                                                  mlir::gpu::Dimension::z};
 };
 
 } // namespace
 
 void mlir::triton::intel::populateSPMDOpToLLVMPattern(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     const TargetInfoBase &targetInfo, PatternBenefit benefit) {
-  patterns.add<GetProgramIdOpConversion>(typeConverter, targetInfo, benefit);
   patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
 }
