Skip to content

Commit b78022a

Browse files
lezcano and apgoucher authored
[BACKEND] Implement generic swizzling when lowering convert_layout (#6982)
We implement a generic swizzling algorithm by @apgoucher that, given two linear layouts, finds the optimal shared memory layout: one that maximises read/write vectorisation and, subject to that, minimises bank conflicts. We also implement an algorithm to find the minimum tile size necessary to perform the `convert_layout` given the restrictions above, and we use it to perform the `convert_layout` iteratively. This PR does not yet implement a lowering to ldmatrix/stmatrix; we'll do that in a future PR. --------- Co-authored-by: Adam P. Goucher <[email protected]>
1 parent 9d11c09 commit b78022a

File tree

20 files changed

+1072
-167
lines changed

20 files changed

+1072
-167
lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class TargetInfoBase {
9696

9797
virtual bool supportLdMatrix() const { return false; }
9898
virtual bool supportStMatrix() const { return false; }
99+
virtual bool isCuda() const { return false; }
99100

100101
// Annotate target specific information to local store operations during
101102
// lowering to LLVM.
include/triton/Tools/GenericSwizzling.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef TRITON_GENERIC_SWIZZLING_H
2+
#define TRITON_GENERIC_SWIZZLING_H
3+
4+
#include "llvm/ADT/ArrayRef.h"
5+
#include "llvm/ADT/SmallVector.h"
6+
#include <cstdint>
7+
8+
namespace mlir::triton {
9+
class LinearLayout;
10+
}
11+
12+
namespace mlir::triton::gpu {
// Given the source and destination linear layouts of a convert_layout, returns
// the shared memory layout that maximises read/write vectorisation and,
// subject to that, minimises bank conflicts. `bitwidth` is the element bit
// width. The returned layout is queried for a "reps" input dimension by the
// allocation analysis.
13+
LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
14+
int32_t bitwidth);
15+
// NOTE(review): presumably returns the log2 of the bank conflicts incurred
// when accessing `smem` with the `src` and `dst` layouts (one entry per
// side) -- confirm the order and meaning of the pair against the
// implementation.
16+
std::pair<int, int> logBankConflicts(const LinearLayout &src,
17+
const LinearLayout &dst,
18+
const LinearLayout &smem,
19+
int32_t bitwidth);
20+
} // namespace mlir::triton::gpu
21+
22+
#endif // TRITON_GENERIC_SWIZZLING_H

include/triton/Tools/LayoutUtils.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,18 @@ LinearLayout zerosLike(const LinearLayout &layout);
116116
// For a layout A with A.hasInDim(kReg), find a permutation of registers action
117117
// such that action.apply(A) may be divisible by B
118118
// It's not always true that the action returned by this function will
119-
// allow us to divideLeft, but it is true that if it if there exists one, it is
120-
// the one returned by this function.
121-
std::optional<ColumnAction> regPermForDivideLeft(const LinearLayout &A,
122-
const LinearLayout &B);
119+
// allow us to divideLeft (resp. divideRight), but it is true that if
120+
// there exists one, it is the one returned by this function.
121+
std::optional<ColumnAction> regPermForDivide(const LinearLayout &A,
122+
const LinearLayout &B, bool left);
123123

124124
// For a layout A with A.hasInDim(kReg), find a permutation of registers action
125125
// such that action.apply(A) has the broadcasted registers removed
126126
ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);
127127

128+
std::pair<int64_t, ColumnAction>
129+
actionAdditiveStrides(const LinearLayout &layout);
130+
128131
// For a layout A with A.hasInDim(kReg), repeat the values so that they have
129132
// the same broadcasting as layout
130133
SmallVector<Value> broadcastAs(const SmallVector<Value> &values,

include/triton/Tools/LinearLayout.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,11 @@ class LinearLayout {
453453
auto getOutDimNames() const { return llvm::make_first_range(outDims); }
454454
auto getOutDimSizes() const { return llvm::make_second_range(outDims); }
455455

456+
// Relevant for reshaping
// Returns a copy of the output dimensions as (name, size) pairs, i.e. the
// same entries that getOutDimNames() / getOutDimSizes() expose separately.
457+
SmallVector<std::pair<StringAttr, int32_t>> getOutDims() const {
458+
return to_vector(outDims);
459+
}
460+
456461
// Gets the position that this outDim occupies in getOutDimNames(). Asserts
457462
// if the dim is not present.
458463
int32_t getOutDimIndex(StringAttr outDim) const;
@@ -620,6 +625,7 @@ class LinearLayout {
620625

621626
// Compute a C such that A = B * C if it exists.
622627
// In other words, C = B^{-1} * A.
628+
// For divideRight, we compute A = C * B, that is, C = A * B^{-1}.
623629
// Note that such a C exists iff (every pair of input/output dim of) A is
624630
// of the form
625631
// [[B, 0],
@@ -633,6 +639,8 @@ class LinearLayout {
633639
// same dimensions as A ensures that C is well-defined.
634640
friend std::optional<LinearLayout> divideLeft(const LinearLayout &A,
635641
const LinearLayout &B);
642+
friend std::optional<LinearLayout> divideRight(const LinearLayout &A,
643+
const LinearLayout &B);
636644

637645
// Returns true if this layout acts trivially (as the identity) on the given
638646
// dimensions. This means that it's the identity on those dimensions, and it
@@ -798,9 +806,10 @@ class ColumnAction {
798806
SmallVector<size_t> action;
799807
StringAttr inDim;
800808
size_t inSizeLog2;
801-
bool isIdentity;
809+
bool isIdentity = true;
802810

803811
public:
812+
ColumnAction() = default;
804813
ColumnAction(ArrayRef<size_t> action, StringAttr inDim, size_t inSizeLog2)
805814
: action(action), inDim(inDim), inSizeLog2(inSizeLog2) {
806815
auto it = llvm::max_element(action);

lib/Analysis/Allocation.cpp

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include "triton/Dialect/Triton/IR/Utility.h"
1111
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1212
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
13+
#include "triton/Tools/GenericSwizzling.h"
14+
#include "triton/Tools/LayoutUtils.h"
1315
#include "llvm/ADT/SmallVector.h"
1416
#include "llvm/Support/Debug.h"
1517
#include "llvm/Support/raw_ostream.h"
@@ -32,6 +34,30 @@ constexpr int kPtrBitWidth = 64;
3234
// Max shmem LDS/STS instruction in bits
3335
constexpr int kMaxShmemVecBitLength = 128;
3436

// Returns the bit width used to size one element of `ty`: kPtrBitWidth when
// the element type is a pointer, otherwise the element bit width clamped from
// below to 8 bits. Callers in this file divide the result by 8 to obtain a
// size in bytes.
37+
static unsigned getBitwidth(RankedTensorType ty) {
38+
auto isPtr = isa<PointerType>(ty.getElementType());
39+
return isPtr ? kPtrBitWidth : std::max(ty.getElementTypeBitWidth(), 8u);
40+
}
41+
// Number of scratch elements needed to convert `srcTy` to `dstTy` through the
// shared memory layout chosen by gpu::optimalSwizzling.
42+
static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
43+
RankedTensorType dstTy) {
44+
auto *ctx = srcTy.getContext();
45+
auto srcLayout = gpu::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
46+
auto dstLayout = gpu::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
// Remove broadcasted registers from both layouts before computing the
// swizzled layout.
47+
srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
48+
dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
// The element bit width is taken from the source type.
49+
auto bitwidth = getBitwidth(srcTy);
50+
auto smem = gpu::optimalSwizzling(srcLayout, dstLayout, bitwidth);
// NOTE(review): "reps" presumably counts the iterations used to perform the
// conversion tile by tile, so dividing by it yields the size of one tile --
// confirm against optimalSwizzling.
51+
auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
52+
return smem.getTotalOutDimSize() / reps;
53+
}
54+
// Number of scratch elements needed to convert `srcTy` to `dstTy` with the
// padded scheme: the element count of the padded repetition shape reported by
// getScratchConfigForCvt.
55+
static unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
56+
RankedTensorType dstTy) {
57+
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
58+
return getNumScratchElements(scratchConfig.paddedRepShape);
59+
}
60+
3561
static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
3662
RankedTensorType dstTy) {
3763
Attribute srcLayout = srcTy.getEncoding();
@@ -135,12 +161,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
135161
scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
136162
// Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
137163
// is the max vectorisation
138-
auto inBitWidth = isa<PointerType>(srcTy.getElementType())
139-
? kPtrBitWidth
140-
: srcTy.getElementTypeBitWidth();
141-
auto outBitWidth = isa<PointerType>(dstTy.getElementType())
142-
? kPtrBitWidth
143-
: dstTy.getElementTypeBitWidth();
164+
auto inBitWidth = getBitwidth(srcTy);
165+
auto outBitWidth = getBitwidth(dstTy);
144166
scratchConfig.inVec =
145167
std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth);
146168
scratchConfig.outVec =
@@ -174,27 +196,18 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
174196
int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
175197
op->getParentOfType<ModuleOp>());
176198
return std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
177-
std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
199+
getBitwidth(dstTy) / 8;
178200
}
179201
if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
180202
auto srcTy = cvtLayout.getSrc().getType();
181203
auto dstTy = cvtLayout.getType();
182-
auto srcEncoding = srcTy.getEncoding();
183-
auto dstEncoding = dstTy.getEncoding();
184-
if (mlir::isa<gpu::SharedEncodingTrait>(srcEncoding) ||
185-
mlir::isa<gpu::SharedEncodingTrait>(dstEncoding)) {
186-
// Conversions from/to shared memory do not need scratch memory.
204+
if (!cvtNeedsSharedMemory(srcTy, dstTy))
187205
return 0;
188-
}
189-
// ConvertLayoutOp with both input/output non-shared_layout
190-
// TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
191-
// also possible to realize it with other approaches in restricted
192-
// conditions, such as warp-shuffle
193-
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
194-
auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
195-
return isa<PointerType>(srcTy.getElementType())
196-
? elems * kPtrBitWidth / 8
197-
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
206+
// Pessimistically take the max. We will revisit later
207+
auto elems = std::max(getNumScratchElemsSwizzledCvt(srcTy, dstTy),
208+
getNumScratchElemsPaddedCvt(srcTy, dstTy));
209+
210+
return elems * getBitwidth(srcTy) / 8;
198211
}
199212
if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
200213
auto value = op->getOperand(0);

0 commit comments

Comments
 (0)