Commit b3b9931
[LAYOUTS] Implement generalized swizzling for convert_layout (#7565)
We generalize the swizzling algorithm to take into account the `ldmatrix`/`stmatrix` instructions and their transposed versions. To do this, we now need a dedicated shared-memory allocator for NVIDIA, as the shared memory required by a `convert_layout` now depends on the instructions we can emit. After cleaning the `stmatrix` path out of the common `convert_layout` lowering, it became clear that we always take the swizzling path, so I changed the allocator to reflect this and updated the many tests that relied on the old allocator and no longer require padding. We also implement an improved lowering for the indexing of `ldmatrix`/`stmatrix`, following the optimisations used for `ld.shared`/`st.shared`.
1 parent ea4bdaf commit b3b9931


41 files changed: +1382 −711 lines. (Large commit: only a subset of the file diffs is reproduced below.)
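For orientation, here is a minimal standalone sketch (illustrative only, not code from this commit) of the XOR-swizzling idea that the generalized algorithm builds on: permute the shared-memory column with bits of the row so that both row-major and column-major warp accesses spread across all 32 banks.

#include <cstdint>

// Illustrative tile: 32x32 elements of 4 bytes, one bank-width per element.
// XOR-ing the column with the row makes every logical column visit all 32
// banks across its rows, so transposed accesses are also conflict-free.
uint32_t swizzledOffset(uint32_t row, uint32_t col) {
  uint32_t swizzledCol = (col ^ row) & 31; // 32 banks of 4 bytes each
  return row * 32 + swizzledCol;           // offset in 4-byte words
}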

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::triton::registerConvertWarpSpecializeToLLVM();
   mlir::triton::registerConvertTritonGPUToLLVMPass();
   mlir::triton::registerConvertNVGPUToLLVMPass();
+  mlir::triton::registerAllocateSharedMemoryNvPass();
   mlir::registerLLVMDIScope();

   // TritonAMDGPUToLLVM passes

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 9 deletions
@@ -38,15 +38,6 @@ class TargetInfoBase {
                        pred);
   }

-  virtual bool canUseStMatrix(RankedTensorType tensorTy,
-                              ArrayRef<unsigned> repShape,
-                              ArrayRef<unsigned> paddedRepShape,
-                              ArrayRef<unsigned> order,
-                              int swizzleByteSize) const = 0;
-
-  virtual void storeMatrixShared(RewriterBase &rewriter, Location loc,
-                                 Value ptr, Value val) const = 0;
-
   virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val,
                            int i) const = 0;
   virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 9 additions & 9 deletions
@@ -10,6 +10,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/IR/Types.h"
+#include "triton/Tools/GenericSwizzling.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/STLExtras.h"
@@ -321,6 +322,10 @@ namespace mlir {
 namespace triton {

 namespace gpu {
+
+std::pair<SmallVector<LocalMemOpTile>, SmallVector<LocalMemOpTile>>
+getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth);
+
 Type getFunctionType(Type resultType, ValueRange operands);

 LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
@@ -608,10 +613,6 @@ std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp);

 std::optional<LLVM::AtomicOrdering> getMemoryOrdering(MemSemantic memOrdering);

-bool isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
-                                ArrayRef<int64_t> allocShape,
-                                triton::gpu::SharedEncodingTrait sharedEnc);
-
 llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx);

 llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type);
@@ -644,11 +645,10 @@ Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src,
                                  const LLVMTypeConverter *typeConverter,
                                  RewriterBase &rewriter);

-LogicalResult
-transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src,
-                             const TargetInfoBase &targetInfo,
-                             const LLVMTypeConverter *typeConverter,
-                             RewriterBase &rewriter);
+void transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src,
+                                  const TargetInfoBase &targetInfo,
+                                  const LLVMTypeConverter *typeConverter,
+                                  RewriterBase &rewriter);

 SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
                                     ArrayRef<Value> args,

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,9 @@

 namespace mlir {

+// Bitwidth of pointers
+constexpr int kPtrBitWidth = 64;
+
 template <typename T, typename U> SmallVector<T> convertType(ArrayRef<U> in) {
   SmallVector<T> out;
   for (const auto &i : in)
@@ -186,6 +189,7 @@ bool isHostSideDescriptor(Value v);

 bool isKernel(FunctionOpInterface funcOp);

+unsigned getBitwidth(RankedTensorType ty);
 } // namespace triton
 } // namespace mlir

include/triton/Tools/GenericSwizzling.h

Lines changed: 37 additions & 5 deletions
@@ -4,17 +4,49 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 #include <cstdint>
+#include <utility>

 namespace mlir::triton {
 class LinearLayout;
-}
+class TargetInfoBase;
+} // namespace mlir::triton

 namespace mlir::triton::gpu {
-LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
-                              int32_t bitwidth);
+// Stores the lane indices used in the contiguous part of an operation and in
+// its address part. The laneAddr part represents just the indices used within
+// one wavefront. For now we only represent tiles with full vectorisation,
+// meaning
+// ld.shared.b32.v4 / st.shared.b32.v4
+// ldmatrix.v4 / stmatrix.v4
+// ldmatrix.trans.v4 / stmatrix.trans.v4
+struct LocalMemOpTile {
+  // If laneContig.size() < log2(128/bitwidth), we assume that the first
+  // log2(128/bitwidth) - laneContig.size() bases are registers
+  llvm::SmallVector<int32_t> laneContig;
+  // If laneAddr.size() < 3, we assume that the first
+  // 3 - laneAddr.size() bases are registers
+  llvm::SmallVector<int32_t> laneAddr;
+};

-std::pair<int, int> logBankConflicts(const LinearLayout &src,
-                                     const LinearLayout &dst,
+// Given the set of possible instructions reported by
+// targetInfo.laneIdTiles(bitwidth), returns the optimal swizzling for those
+// instructions, together with the pair of indices into the ldStTiles needed
+// to lower this swizzling.
+std::pair<LinearLayout, std::pair<int32_t, int32_t>>
+optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
+                 llvm::ArrayRef<LocalMemOpTile> srcTiles,
+                 llvm::ArrayRef<LocalMemOpTile> dstTiles, int32_t bitwidth);
+
+LinearLayout optimalSwizzlingLdSt(const LinearLayout &src,
+                                  const LinearLayout &dst, int32_t bitwidth);
+
+std::pair<int, int> logBankConflictsLdSt(const LinearLayout &src,
+                                         const LinearLayout &dst,
+                                         const LinearLayout &smem,
+                                         int32_t bitwidth);
+
+std::pair<int, int> logBankConflicts(llvm::ArrayRef<int32_t> tileSrc,
+                                     llvm::ArrayRef<int32_t> tileDst,
                                      const LinearLayout &smem,
                                      int32_t bitwidth);
 } // namespace mlir::triton::gpu

include/triton/Tools/LayoutUtils.h

Lines changed: 2 additions & 1 deletion
@@ -126,7 +126,8 @@ std::optional<ColumnAction> regPermForDivide(const LinearLayout &A,
 ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);

 std::pair<int64_t, ColumnAction>
-actionAdditiveStrides(const LinearLayout &layout, uint64_t maskSpanOffsets);
+actionAdditiveStrides(const LinearLayout &layout, const LinearLayout addrLayout,
+                      uint64_t maskSpanOffsets);

 // For a layout A with A.hasInDim(kReg), repeat the values so that they have
 // the same broadcasting as layout

lib/Analysis/Allocation.cpp

Lines changed: 7 additions & 16 deletions
@@ -29,14 +29,13 @@ namespace mlir {
 //===----------------------------------------------------------------------===//
 namespace triton {

-// Bitwidth of pointers
-constexpr int kPtrBitWidth = 64;
 // Max shmem LDS/STS instruction in bits
 constexpr int kMaxShmemVecBitLength = 128;

-static unsigned getBitwidth(RankedTensorType ty) {
-  auto isPtr = isa<PointerType>(ty.getElementType());
-  return isPtr ? kPtrBitWidth : std::max(ty.getElementTypeBitWidth(), 8u);
+unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
+                                     RankedTensorType dstTy) {
+  auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
+  return getNumScratchElements(scratchConfig.paddedRepShape);
 }

 unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
@@ -47,17 +46,11 @@ unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
   srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
   dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
   auto bitwidth = getBitwidth(srcTy);
-  auto smem = gpu::optimalSwizzling(srcLayout, dstLayout, bitwidth);
+  auto smem = gpu::optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
   auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
   return smem.getTotalOutDimSize() / reps;
 }

-unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
-                                     RankedTensorType dstTy) {
-  auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
-  return getNumScratchElements(scratchConfig.paddedRepShape);
-}
-
 static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
                                                RankedTensorType dstTy) {
   Attribute srcLayout = srcTy.getEncoding();
@@ -215,10 +208,8 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
     auto dstTy = cvtLayout.getType();
     if (!cvtNeedsSharedMemory(srcTy, dstTy))
       return 0;
-    // Pesimistically take the max. We will revisit later
-    auto elems = std::max(getNumScratchElemsSwizzledCvt(srcTy, dstTy),
-                          getNumScratchElemsPaddedCvt(srcTy, dstTy));
-
+    // The generic pass uses swizzling
+    auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
     return elems * getBitwidth(srcTy) / 8;
   }
   if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
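As a back-of-envelope example of the new scratch-size rule (hypothetical numbers, not taken from the commit), mirroring the arithmetic in getNumScratchElemsSwizzledCvt and defaultAllocationAnalysisScratchSizeFn:

#include <cassert>

int main() {
  // Hypothetical: optimalSwizzlingLdSt returns a shared-memory layout whose
  // total out-dim size is 4096 elements, split over 2 reps (round trips
  // through the same scratch tile), for a 16-bit conversion.
  unsigned totalOutDimSize = 4096;         // smem.getTotalOutDimSize()
  unsigned reps = 2;                       // smem.getInDimSize("reps")
  unsigned elems = totalOutDimSize / reps; // 2048 scratch elements
  unsigned bytes = elems * 16 / 8;         // elems * bitwidth / 8
  assert(bytes == 4096);                   // shared memory reserved, in bytes
  return 0;
}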

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 145 additions & 15 deletions
@@ -63,7 +63,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     } else if (llvm::is_contained(dims, kWarp)) {
       // Case 2: Transfer between values in the same CTA, in which case we move
       // values through shared memory.
-      return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
+      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
+      return success();
     } else if (llvm::is_contained(dims, kLane)) {
       // Case 3. Transfer between values in the same warp, in which case we try
       // to move values using warp shuffles, though if the pattern is
@@ -74,7 +75,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       // TODO: Since data is only transferred within a warp over shared memory,
       // we should use `bar.warp.sync` instead of `barrier`, which will improve
       // latency when warps issue barriers on different cycles.
-      return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
+      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
+      return success();
     } else if (llvm::is_contained(dims, kRegister)) {
       // Case 4. Transfer between values in the same thread, in which case we
       // simply reorder the elements of adaptor.getSrc().
@@ -110,24 +112,152 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     return success();
   }

-  LogicalResult transferWithinBlock(ConvertLayoutOp op,
-                                    const LinearLayout &srcLayout,
-                                    const LinearLayout &dstLayout,
-                                    OpAdaptor adaptor,
-                                    ConversionPatternRewriter &rewriter) const {
-    assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
+  SmallVector<Value> transferWithinBlockSwizzlingImpl(
+      Location loc, ConversionPatternRewriter &rewriter,
+      const LinearLayout &srcLayout, const LinearLayout &dstLayout,
+      ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
+    auto *ctx = rewriter.getContext();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    // We handle transformations recursively as they all need a preprocessing
+    // and a postprocessing step.
+
+    // Handle pointer types as 64-bit integers
+    if (isa<LLVM::LLVMPointerType>(llvmElemTy)) {
+      auto llvmElemTyPtr = i64_ty;
+      auto newInVals = llvm::to_vector(llvm::map_range(inVals, [&](Value v) {
+        return b.ptrtoint(llvmElemTyPtr, v).getResult();
+      }));
+      auto outVals =
+          transferWithinBlockSwizzlingImpl(loc, rewriter, srcLayout, dstLayout,
+                                           newInVals, llvmElemTyPtr, smemBase);
+      for (auto &v : outVals) {
+        v = b.inttoptr(llvmElemTy, v);
+      }
+      return outVals;
+    }

-    // Try to use swizzling to implement the conversion
-    if (succeeded(transferWithinBlockSwizzling(op, adaptor.getSrc(), targetInfo,
-                                               getTypeConverter(), rewriter))) {
-      return success();
+    // Handle sub-byte elements like i1
+    if (llvmElemTy.getIntOrFloatBitWidth() < 8) {
+      // Upcast to i8
+      auto i8ElemTy = i8_ty;
+      auto newInVals = llvm::to_vector(llvm::map_range(
+          inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
+      auto outVals = transferWithinBlockSwizzlingImpl(
+          loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
+      for (auto &v : outVals) {
+        v = b.trunc(llvmElemTy, v);
+      }
+      return outVals;
     }

-    Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
-                                              getTypeConverter(), rewriter);
+    // Remove broadcasting in src
+    auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
+    if (!removeBroadcastSrc.isIdentity()) {
+      auto prmtSrc = removeBroadcastSrc.apply(srcLayout);
+      auto newInVals = removeBroadcastSrc.apply(inVals);
+      return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout,
+                                              newInVals, llvmElemTy, smemBase);
+    }

+    // Remove broadcasting in dst
+    auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout);
+    if (!removeBroadcastDst.isIdentity()) {
+      auto prmtDst = removeBroadcastDst.apply(dstLayout);
+      auto outVals = transferWithinBlockSwizzlingImpl(
+          loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
+      return broadcastAs(outVals, dstLayout);
+    }
+
+    // At this point we have a type that's at least 8-bit
+    // and we don't have broadcasting in the registers
+    auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
+    auto smem = optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
+
+    // Extract reps from smem
+    auto kReg = str_attr("register");
+    auto kReps = str_attr("reps");
+    auto nReps = smem.getInDimSize(kReps);
+    auto reps = LinearLayout::identity1D(nReps, kReg, kReps);
+
+    auto totalStoreCvt = srcLayout.invertAndCompose(smem);
+    auto totalLoadCvt = dstLayout.invertAndCompose(smem);
+
+    // The permutation exists by construction of the reps dimension in
+    // optimalSwizzling
+    auto permStore =
+        regPermForDivide(totalStoreCvt, reps, /*left=*/false).value();
+    totalStoreCvt = permStore.apply(totalStoreCvt);
+    auto permutedInVals = permStore.apply(inVals);
+    auto permLoad =
+        regPermForDivide(totalLoadCvt, reps, /*left=*/false).value();
+    totalLoadCvt = permLoad.apply(totalLoadCvt);
+
+    // Remove the reps and flatten into offset
+    auto storeCvt = *divideRight(totalStoreCvt, reps);
+    auto loadCvt = *divideRight(totalLoadCvt, reps);
+    auto kOffset = str_attr("offset");
+    storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}});
+    loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}});
+
+    auto tileSize = storeCvt.getInDimSize(kReg);
+
+    assert(permutedInVals.size() == tileSize * nReps);
+    SmallVector<Value> outVals;
+    auto affineOffset = b.i32_val(0);
+    auto maskSpanAffineOffset = 0;
+    auto noPaddingOffset = [](Value v) { return v; };
+    for (int i = 0; i < nReps; ++i) {
+      if (i > 0)
+        b.barrier();
+
+      auto tileInVals =
+          ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
+      // Store
+      lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
+                      noPaddingOffset, affineOffset, maskSpanAffineOffset,
+                      rewriter, targetInfo);
+      b.barrier();
+      // Load
+      SmallVector<Value> tileOutVals = lowerLdStShared(
+          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
+          affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
+      llvm::append_range(outVals, tileOutVals);
+    }
+
+    // Undo the permLoad used to divideRight
+    outVals = permLoad.inverse().apply(outVals);
+    return outVals;
+  }
+
+  void transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
+                                    ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    auto *ctx = op.getContext();
+    auto srcTy = op.getSrc().getType();
+    auto dstTy = op.getType();
+
+    // Remove the kBlock dimension from the layout as it's the identity in the
+    // cvt
+    auto srcLayout = toLinearLayout(srcTy);
+    auto dstLayout = toLinearLayout(dstTy);
+    auto kReg = str_attr("register");
+    auto kLane = str_attr("lane");
+    auto kWarp = str_attr("warp");
+    srcLayout = srcLayout.sublayout({kReg, kLane, kWarp},
+                                    to_vector(srcLayout.getOutDimNames()));
+    dstLayout = dstLayout.sublayout({kReg, kLane, kWarp},
+                                    to_vector(dstLayout.getOutDimNames()));
+
+    auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
+    auto smemBase =
+        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+    auto inVals = unpackLLElements(loc, src, rewriter);
+    auto outVals = transferWithinBlockSwizzlingImpl(
+        loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
+
+    Value result =
+        packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy);
     rewriter.replaceOp(op, result);
-    return success();
   }

   // Use warp shuffles to implement a layout conversion where data only needs to
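The recursion in transferWithinBlockSwizzlingImpl follows a normalize-recurse-undo pattern: each special case rewrites the inputs, recurses, then inverts the rewrite on the outputs, so the base case only ever sees values of at least 8 bits, non-pointer, with no broadcast registers. A minimal standalone sketch of that pattern (plain C++, not Triton code):

#include <cstdint>
#include <vector>

// Each guard normalizes the inputs, recurses, then undoes the normalization
// on the outputs, mirroring the ptrtoint/inttoptr and zext/trunc pairs above.
std::vector<uint32_t> transfer(std::vector<uint32_t> vals, unsigned bitwidth) {
  if (bitwidth < 8) {
    // Pre: widen to 8 bits (stands in for zext), then recurse.
    auto out = transfer(std::move(vals), 8);
    // Post: narrow back to the original width (stands in for trunc).
    for (auto &v : out)
      v &= (1u << bitwidth) - 1;
    return out;
  }
  // Base case: the real code stores tiles to shared memory and reloads them
  // through the swizzled layout here.
  return vals;
}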
