Skip to content

Commit 661f264

Browse files
Merge commit 'b3b9931cb7ed07a6d7a3833c8dcb7b7b519e882f'
2 parents eeb07f7 + b3b9931 commit 661f264

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+1358
-650
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
8989
mlir::triton::registerConvertWarpSpecializeToLLVM();
9090
mlir::triton::registerConvertTritonGPUToLLVMPass();
9191
mlir::triton::registerConvertNVGPUToLLVMPass();
92+
mlir::triton::registerAllocateSharedMemoryNvPass();
9293
mlir::registerLLVMDIScope();
9394
mlir::triton::gpu::intel::registerTritonAnnotateModulePass();
9495
mlir::triton::gpu::intel::registerTritonIntelGPUPasses();

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,6 @@ class TargetInfoBase {
3838
pred);
3939
}
4040

41-
virtual bool canUseStMatrix(RankedTensorType tensorTy,
42-
ArrayRef<unsigned> repShape,
43-
ArrayRef<unsigned> paddedRepShape,
44-
ArrayRef<unsigned> order,
45-
int swizzleByteSize) const = 0;
46-
47-
virtual void storeMatrixShared(RewriterBase &rewriter, Location loc,
48-
Value ptr, Value val) const = 0;
49-
5041
virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val,
5142
int i) const = 0;
5243
virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1111
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
1212
#include "triton/Dialect/TritonGPU/IR/Types.h"
13+
#include "triton/Tools/GenericSwizzling.h"
1314
#include "triton/Tools/LinearLayout.h"
1415
#include "triton/Tools/StrUtil.h"
1516
#include "llvm/ADT/STLExtras.h"
@@ -321,6 +322,10 @@ namespace mlir {
321322
namespace triton {
322323

323324
namespace gpu {
325+
326+
std::pair<SmallVector<LocalMemOpTile>, SmallVector<LocalMemOpTile>>
327+
getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth);
328+
324329
Type getFunctionType(Type resultType, ValueRange operands);
325330

326331
LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
@@ -607,10 +612,6 @@ std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp);
607612

608613
std::optional<LLVM::AtomicOrdering> getMemoryOrdering(MemSemantic memOrdering);
609614

610-
bool isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
611-
ArrayRef<int64_t> allocShape,
612-
triton::gpu::SharedEncodingTrait sharedEnc);
613-
614615
llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx);
615616

616617
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type);

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
namespace mlir {
99

10+
// Bitwidth of pointers
11+
constexpr int kPtrBitWidth = 64;
12+
1013
template <typename T, typename U> SmallVector<T> convertType(ArrayRef<U> in) {
1114
SmallVector<T> out;
1215
for (const auto &i : in)
@@ -186,6 +189,7 @@ bool isHostSideDescriptor(Value v);
186189

187190
bool isKernel(FunctionOpInterface funcOp);
188191

192+
unsigned getBitwidth(RankedTensorType ty);
189193
} // namespace triton
190194
} // namespace mlir
191195

include/triton/Tools/GenericSwizzling.h

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,49 @@
44
#include "llvm/ADT/ArrayRef.h"
55
#include "llvm/ADT/SmallVector.h"
66
#include <cstdint>
7+
#include <utility>
78

89
namespace mlir::triton {
910
class LinearLayout;
10-
}
11+
class TargetInfoBase;
12+
} // namespace mlir::triton
1113

1214
namespace mlir::triton::gpu {
13-
LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
14-
int32_t bitwidth);
15+
// Store the lane indices that are used in the contiguous part
16+
// of an operation and in the address part.
17+
// The laneAddr part just represents the indices used in one wavefront
18+
// For now we just represent tiles with full vectorisation, meaning
19+
// ld.shared.b32.v4/st.shared.b32.v4
20+
// ldmatrix.v4 / stmatrix.v4
21+
// ldmatrix.trans.v4 / stmatrix.trans.v4
22+
struct LocalMemOpTile {
23+
// If laneContig.size() < log2(128/bitwidth), we assume that
24+
// the first log2(128/bitwidth) - laneContig.size() bases are registers
25+
llvm::SmallVector<int32_t> laneContig;
26+
// If laneAddr.size() < 3, we assume that the first
27+
// 3 - laneAddr.size() bases are registers
28+
llvm::SmallVector<int32_t> laneAddr;
29+
};
1530

16-
std::pair<int, int> logBankConflicts(const LinearLayout &src,
17-
const LinearLayout &dst,
31+
// Given a set of possible instructions given by
32+
// targetInfo.laneIdTiles(bitwidth) returns the optimal swizzling given these
33+
// instructions and a pair of indices into the ldStTiles that's needed to lower
34+
// this swizzling
35+
std::pair<LinearLayout, std::pair<int32_t, int32_t>>
36+
optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
37+
llvm::ArrayRef<LocalMemOpTile> srcTiles,
38+
llvm::ArrayRef<LocalMemOpTile> dstTiles, int32_t bitwidth);
39+
40+
LinearLayout optimalSwizzlingLdSt(const LinearLayout &src,
41+
const LinearLayout &dst, int32_t bitwidth);
42+
43+
std::pair<int, int> logBankConflictsLdSt(const LinearLayout &src,
44+
const LinearLayout &dst,
45+
const LinearLayout &smem,
46+
int32_t bitwidth);
47+
48+
std::pair<int, int> logBankConflicts(llvm::ArrayRef<int32_t> tileSrc,
49+
llvm::ArrayRef<int32_t> tileDst,
1850
const LinearLayout &smem,
1951
int32_t bitwidth);
2052
} // namespace mlir::triton::gpu

include/triton/Tools/LayoutUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ std::optional<ColumnAction> regPermForDivide(const LinearLayout &A,
126126
ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);
127127

128128
std::pair<int64_t, ColumnAction>
129-
actionAdditiveStrides(const LinearLayout &layout, uint64_t maskSpanOffsets);
129+
actionAdditiveStrides(const LinearLayout &layout, const LinearLayout addrLayout,
130+
uint64_t maskSpanOffsets);
130131

131132
// For a layout A with A.hasInDim(kReg), repeat the values so that they have
132133
// the same broadcasting as layout

lib/Analysis/Allocation.cpp

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,13 @@ namespace mlir {
2929
//===----------------------------------------------------------------------===//
3030
namespace triton {
3131

32-
// Bitwidth of pointers
33-
constexpr int kPtrBitWidth = 64;
3432
// Max shmem LDS/STS instruction in bits
3533
constexpr int kMaxShmemVecBitLength = 128;
3634

37-
static unsigned getBitwidth(RankedTensorType ty) {
38-
auto isPtr = isa<PointerType>(ty.getElementType());
39-
return isPtr ? kPtrBitWidth : std::max(ty.getElementTypeBitWidth(), 8u);
35+
unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
36+
RankedTensorType dstTy) {
37+
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
38+
return getNumScratchElements(scratchConfig.paddedRepShape);
4039
}
4140

4241
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
@@ -47,17 +46,11 @@ unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
4746
srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
4847
dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
4948
auto bitwidth = getBitwidth(srcTy);
50-
auto smem = gpu::optimalSwizzling(srcLayout, dstLayout, bitwidth);
49+
auto smem = gpu::optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
5150
auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
5251
return smem.getTotalOutDimSize() / reps;
5352
}
5453

55-
unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
56-
RankedTensorType dstTy) {
57-
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
58-
return getNumScratchElements(scratchConfig.paddedRepShape);
59-
}
60-
6154
static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
6255
RankedTensorType dstTy) {
6356
Attribute srcLayout = srcTy.getEncoding();
@@ -215,10 +208,8 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
215208
auto dstTy = cvtLayout.getType();
216209
if (!cvtNeedsSharedMemory(srcTy, dstTy))
217210
return 0;
218-
// Pesimistically take the max. We will revisit later
219-
auto elems = std::max(getNumScratchElemsSwizzledCvt(srcTy, dstTy),
220-
getNumScratchElemsPaddedCvt(srcTy, dstTy));
221-
211+
// The generic pass uses swizzling
212+
auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
222213
return elems * getBitwidth(srcTy) / 8;
223214
}
224215
if (isa<AtomicRMWOp, AtomicCASOp>(op)) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 8 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
6363
} else if (llvm::is_contained(dims, kWarp)) {
6464
// Case 2: Transfer between values in the same CTA, in which case we move
6565
// values through shared memory.
66-
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
66+
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
67+
return success();
6768
} else if (llvm::is_contained(dims, kLane)) {
6869
// Case 3. Transfer between values in the same warp, in which case we try
6970
// to move values using warp shuffles, though if the pattern is
@@ -74,7 +75,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
7475
// TODO: Since data is only transferred within a warp over shared memory,
7576
// we should use `bar.warp.sync` instead of `barrier`, which will improve
7677
// latency when warps issue barriers on different cycles.
77-
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
78+
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
79+
return success();
7880
} else if (llvm::is_contained(dims, kRegister)) {
7981
// Case 4. Transfer between values in the same thread, in which case we
8082
// simply reorder the elements of adaptor.getSrc().
@@ -169,7 +171,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
169171
// At this point we have a type that's at least 8-bit
170172
// and we don't have broadcasting in the registers
171173
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
172-
auto smem = optimalSwizzling(srcLayout, dstLayout, bitwidth);
174+
auto smem = optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
173175

174176
// Extract reps from smem
175177
auto kReg = str_attr("register");
@@ -201,9 +203,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
201203

202204
assert(permutedInVals.size() == tileSize * nReps);
203205
SmallVector<Value> outVals;
204-
auto noPaddingOffset = [](Value v) { return v; };
205206
auto affineOffset = b.i32_val(0);
206207
auto maskSpanAffineOffset = 0;
208+
auto noPaddingOffset = [](Value v) { return v; };
207209
for (int i = 0; i < nReps; ++i) {
208210
if (i > 0)
209211
b.barrier();
@@ -227,20 +229,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
227229
return outVals;
228230
}
229231

230-
LogicalResult
231-
transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
232-
ConversionPatternRewriter &rewriter) const {
233-
// Fallback for now to standard lowering if it can use stmatrix
234-
auto scratchConfig =
235-
getScratchConfigForCvt(op.getSrc().getType(), op.getType());
236-
bool isStMatrix = targetInfo.canUseStMatrix(
237-
op.getSrc().getType(), scratchConfig.repShape,
238-
scratchConfig.paddedRepShape, scratchConfig.order,
239-
/*swizzleByteSize=*/0);
240-
if (isStMatrix) {
241-
return failure();
242-
}
243-
232+
void transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
233+
ConversionPatternRewriter &rewriter) const {
244234
auto loc = op.getLoc();
245235
auto *ctx = op.getContext();
246236
auto srcTy = op.getSrc().getType();
@@ -268,28 +258,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
268258
Value result =
269259
packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy);
270260
rewriter.replaceOp(op, result);
271-
return success();
272-
}
273-
274-
LogicalResult transferWithinBlock(ConvertLayoutOp op,
275-
const LinearLayout &srcLayout,
276-
const LinearLayout &dstLayout,
277-
OpAdaptor adaptor,
278-
ConversionPatternRewriter &rewriter) const {
279-
assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
280-
281-
// Try to use swizzling to implement the conversion
282-
// HACK Remove once AMD tests pass for the swizzling path
283-
if (targetInfo.isCuda() && succeeded(transferWithinBlockSwizzling(
284-
op, adaptor.getSrc(), rewriter))) {
285-
return success();
286-
}
287-
288-
Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
289-
getTypeConverter(), rewriter);
290-
291-
rewriter.replaceOp(op, result);
292-
return success();
293261
}
294262

295263
// Use warp shuffles to implement a layout conversion where data only needs to

0 commit comments

Comments
 (0)