Skip to content

Commit 563c2c1

Browse files
Reland "[LAYOUTS] Implement generalized swizzling for convert_layout (#7565)" (#4897)
Fixes #4883 --------- Signed-off-by: Whitney Tsang <[email protected]>
2 parents 3c154c8 + 0d2fd2d commit 563c2c1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1366
-663
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
8989
mlir::triton::registerConvertWarpSpecializeToLLVM();
9090
mlir::triton::registerConvertTritonGPUToLLVMPass();
9191
mlir::triton::registerConvertNVGPUToLLVMPass();
92+
mlir::triton::registerAllocateSharedMemoryNvPass();
9293
mlir::registerLLVMDIScope();
9394
mlir::triton::gpu::intel::registerTritonAnnotateModulePass();
9495
mlir::triton::gpu::intel::registerTritonIntelGPUPasses();

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,6 @@ class TargetInfoBase {
3838
pred);
3939
}
4040

41-
virtual bool canUseStMatrix(RankedTensorType tensorTy,
42-
ArrayRef<unsigned> repShape,
43-
ArrayRef<unsigned> paddedRepShape,
44-
ArrayRef<unsigned> order,
45-
int swizzleByteSize) const = 0;
46-
47-
virtual void storeMatrixShared(RewriterBase &rewriter, Location loc,
48-
Value ptr, Value val) const = 0;
49-
5041
virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val,
5142
int i) const = 0;
5243
virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1111
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
1212
#include "triton/Dialect/TritonGPU/IR/Types.h"
13+
#include "triton/Tools/GenericSwizzling.h"
1314
#include "triton/Tools/LinearLayout.h"
1415
#include "triton/Tools/StrUtil.h"
1516
#include "llvm/ADT/STLExtras.h"
@@ -321,6 +322,10 @@ namespace mlir {
321322
namespace triton {
322323

323324
namespace gpu {
325+
326+
std::pair<SmallVector<LocalMemOpTile>, SmallVector<LocalMemOpTile>>
327+
getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth);
328+
324329
Type getFunctionType(Type resultType, ValueRange operands);
325330

326331
LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
@@ -608,10 +613,6 @@ std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp);
608613

609614
std::optional<LLVM::AtomicOrdering> getMemoryOrdering(MemSemantic memOrdering);
610615

611-
bool isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
612-
ArrayRef<int64_t> allocShape,
613-
triton::gpu::SharedEncodingTrait sharedEnc);
614-
615616
llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx);
616617

617618
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type);

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
namespace mlir {
99

10+
// Bitwidth of pointers
11+
constexpr int kPtrBitWidth = 64;
12+
1013
template <typename T, typename U> SmallVector<T> convertType(ArrayRef<U> in) {
1114
SmallVector<T> out;
1215
for (const auto &i : in)
@@ -186,6 +189,7 @@ bool isHostSideDescriptor(Value v);
186189

187190
bool isKernel(FunctionOpInterface funcOp);
188191

192+
unsigned getBitwidth(RankedTensorType ty);
189193
} // namespace triton
190194
} // namespace mlir
191195

include/triton/Tools/GenericSwizzling.h

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,49 @@
44
#include "llvm/ADT/ArrayRef.h"
55
#include "llvm/ADT/SmallVector.h"
66
#include <cstdint>
7+
#include <utility>
78

89
namespace mlir::triton {
910
class LinearLayout;
10-
}
11+
class TargetInfoBase;
12+
} // namespace mlir::triton
1113

1214
namespace mlir::triton::gpu {
13-
LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
14-
int32_t bitwidth);
15+
// Store the lane indices that are used in the contiguous part
16+
// of an operation and in the address part.
17+
// The laneAddr part just represents the indices used in one wavefront
18+
// For now we just represent tiles with full vectorisation, meaning
19+
// ld.shared.b32.v4/st.shared.b32.v4
20+
// ldmatrix.v4 / stmatrix.v4
21+
// ldmatrix.trans.v4 / stmatrix.trans.v4
22+
struct LocalMemOpTile {
23+
// If laneContig.size() < log2(128/bitwidth), we assume that
24+
// the first log2(128/bitwidth) - laneContig.size() bases are registers
25+
llvm::SmallVector<int32_t> laneContig;
26+
// If laneAddr.size() < 3, we assume that the first
27+
// 3 - laneAddr.size() bases are registers
28+
llvm::SmallVector<int32_t> laneAddr;
29+
};
1530

16-
std::pair<int, int> logBankConflicts(const LinearLayout &src,
17-
const LinearLayout &dst,
31+
// Given a set of possible instructions given by
32+
// targetInfo.laneIdTiles(bitwidth) returns the optimal swizzling given these
33+
// instructions and a pair of indices into the ldStTiles that's needed to lower
34+
// this swizzling
35+
std::pair<LinearLayout, std::pair<int32_t, int32_t>>
36+
optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
37+
llvm::ArrayRef<LocalMemOpTile> srcTiles,
38+
llvm::ArrayRef<LocalMemOpTile> dstTiles, int32_t bitwidth);
39+
40+
LinearLayout optimalSwizzlingLdSt(const LinearLayout &src,
41+
const LinearLayout &dst, int32_t bitwidth);
42+
43+
std::pair<int, int> logBankConflictsLdSt(const LinearLayout &src,
44+
const LinearLayout &dst,
45+
const LinearLayout &smem,
46+
int32_t bitwidth);
47+
48+
std::pair<int, int> logBankConflicts(llvm::ArrayRef<int32_t> tileSrc,
49+
llvm::ArrayRef<int32_t> tileDst,
1850
const LinearLayout &smem,
1951
int32_t bitwidth);
2052
} // namespace mlir::triton::gpu

include/triton/Tools/LayoutUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ std::optional<ColumnAction> regPermForDivide(const LinearLayout &A,
126126
ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);
127127

128128
std::pair<int64_t, ColumnAction>
129-
actionAdditiveStrides(const LinearLayout &layout, uint64_t maskSpanOffsets);
129+
actionAdditiveStrides(const LinearLayout &layout, const LinearLayout addrLayout,
130+
uint64_t maskSpanOffsets);
130131

131132
// For a layout A with A.hasInDim(kReg), repeat the values so that they have
132133
// the same broadcasting as layout

lib/Analysis/Allocation.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,13 @@ namespace mlir {
2929
//===----------------------------------------------------------------------===//
3030
namespace triton {
3131

32-
// Bitwidth of pointers
33-
constexpr int kPtrBitWidth = 64;
3432
// Max shmem LDS/STS instruction in bits
3533
constexpr int kMaxShmemVecBitLength = 128;
3634

37-
static unsigned getBitwidth(RankedTensorType ty) {
38-
auto isPtr = isa<PointerType>(ty.getElementType());
39-
return isPtr ? kPtrBitWidth : std::max(ty.getElementTypeBitWidth(), 8u);
35+
unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
36+
RankedTensorType dstTy) {
37+
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
38+
return getNumScratchElements(scratchConfig.paddedRepShape);
4039
}
4140

4241
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
@@ -47,17 +46,11 @@ unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
4746
srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
4847
dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
4948
auto bitwidth = getBitwidth(srcTy);
50-
auto smem = gpu::optimalSwizzling(srcLayout, dstLayout, bitwidth);
49+
auto smem = gpu::optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
5150
auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
5251
return smem.getTotalOutDimSize() / reps;
5352
}
5453

55-
unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
56-
RankedTensorType dstTy) {
57-
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
58-
return getNumScratchElements(scratchConfig.paddedRepShape);
59-
}
60-
6154
static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
6255
RankedTensorType dstTy) {
6356
Attribute srcLayout = srcTy.getEncoding();

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
119119

120120
// Try to use swizzling to implement the conversion
121121
// HACK Remove once XPU tests pass for the swizzling path
122-
if (!targetInfo.isXpu() &&
123-
succeeded(transferWithinBlockSwizzling(op, adaptor.getSrc(), targetInfo,
124-
getTypeConverter(), rewriter))) {
122+
if (!targetInfo.isXpu()) {
123+
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
125124
return success();
126125
}
127126

@@ -132,6 +131,154 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
132131
return success();
133132
}
134133

134+
SmallVector<Value> transferWithinBlockSwizzlingImpl(
135+
Location loc, ConversionPatternRewriter &rewriter,
136+
const LinearLayout &srcLayout, const LinearLayout &dstLayout,
137+
ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
138+
auto *ctx = rewriter.getContext();
139+
auto b = TritonLLVMOpBuilder(loc, rewriter);
140+
// We handle transformations recursively as they all need a preprocessing
141+
// and a postprocessing step.
142+
143+
// Handle pointer types as 64-bit integers
144+
if (isa<LLVM::LLVMPointerType>(llvmElemTy)) {
145+
auto llvmElemTyPtr = i64_ty;
146+
auto newInVals = llvm::to_vector(llvm::map_range(inVals, [&](Value v) {
147+
return b.ptrtoint(llvmElemTyPtr, v).getResult();
148+
}));
149+
auto outVals =
150+
transferWithinBlockSwizzlingImpl(loc, rewriter, srcLayout, dstLayout,
151+
newInVals, llvmElemTyPtr, smemBase);
152+
for (auto &v : outVals) {
153+
v = b.inttoptr(llvmElemTy, v);
154+
}
155+
return outVals;
156+
}
157+
158+
// Handle sub-byte elements like i1
159+
if (llvmElemTy.getIntOrFloatBitWidth() < 8) {
160+
// Upcast to i8
161+
auto i8ElemTy = i8_ty;
162+
auto newInVals = llvm::to_vector(llvm::map_range(
163+
inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
164+
auto outVals = transferWithinBlockSwizzlingImpl(
165+
loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
166+
for (auto &v : outVals) {
167+
v = b.trunc(llvmElemTy, v);
168+
}
169+
return outVals;
170+
}
171+
172+
// Remove broadcasting in src
173+
auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
174+
if (!removeBroadcastSrc.isIdentity()) {
175+
auto prmtSrc = removeBroadcastSrc.apply(srcLayout);
176+
auto newInVals = removeBroadcastSrc.apply(inVals);
177+
return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout,
178+
newInVals, llvmElemTy, smemBase);
179+
}
180+
181+
// Remove broadcasting in dst
182+
auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout);
183+
if (!removeBroadcastDst.isIdentity()) {
184+
auto prmtDst = removeBroadcastDst.apply(dstLayout);
185+
auto outVals = transferWithinBlockSwizzlingImpl(
186+
loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
187+
return broadcastAs(outVals, dstLayout);
188+
}
189+
190+
// At this point we have a type that's at least 8-bit
191+
// and we don't have broadcasting in the registers
192+
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
193+
auto smem = optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
194+
195+
// Extract reps from smem
196+
auto kReg = str_attr("register");
197+
auto kReps = str_attr("reps");
198+
auto nReps = smem.getInDimSize(kReps);
199+
auto reps = LinearLayout::identity1D(nReps, kReg, kReps);
200+
201+
auto totalStoreCvt = srcLayout.invertAndCompose(smem);
202+
auto totalLoadCvt = dstLayout.invertAndCompose(smem);
203+
204+
// The permutation exists by construction of the reps dimension in
205+
// optimalSwizzling
206+
auto permStore =
207+
regPermForDivide(totalStoreCvt, reps, /*left=*/false).value();
208+
totalStoreCvt = permStore.apply(totalStoreCvt);
209+
auto permutedInVals = permStore.apply(inVals);
210+
auto permLoad =
211+
regPermForDivide(totalLoadCvt, reps, /*left=*/false).value();
212+
totalLoadCvt = permLoad.apply(totalLoadCvt);
213+
214+
// Remove the reps and flatten into offset
215+
auto storeCvt = *divideRight(totalStoreCvt, reps);
216+
auto loadCvt = *divideRight(totalLoadCvt, reps);
217+
auto kOffset = str_attr("offset");
218+
storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}});
219+
loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}});
220+
221+
auto tileSize = storeCvt.getInDimSize(kReg);
222+
223+
assert(permutedInVals.size() == tileSize * nReps);
224+
SmallVector<Value> outVals;
225+
auto affineOffset = b.i32_val(0);
226+
auto maskSpanAffineOffset = 0;
227+
auto noPaddingOffset = [](Value v) { return v; };
228+
for (int i = 0; i < nReps; ++i) {
229+
if (i > 0)
230+
b.barrier();
231+
232+
auto tileInVals =
233+
ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
234+
// Store
235+
lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
236+
noPaddingOffset, affineOffset, maskSpanAffineOffset,
237+
rewriter, targetInfo);
238+
b.barrier();
239+
// Load
240+
SmallVector<Value> tileOutVals = lowerLdStShared(
241+
loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
242+
affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
243+
llvm::append_range(outVals, tileOutVals);
244+
}
245+
246+
// Undo the permLoad used to divideRight
247+
outVals = permLoad.inverse().apply(outVals);
248+
return outVals;
249+
}
250+
251+
void transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
252+
ConversionPatternRewriter &rewriter) const {
253+
auto loc = op.getLoc();
254+
auto *ctx = op.getContext();
255+
auto srcTy = op.getSrc().getType();
256+
auto dstTy = op.getType();
257+
258+
// Remove the kBlock dimension from the layout as it's the identity in the
259+
// cvt
260+
auto srcLayout = toLinearLayout(srcTy);
261+
auto dstLayout = toLinearLayout(dstTy);
262+
auto kReg = str_attr("register");
263+
auto kLane = str_attr("lane");
264+
auto kWarp = str_attr("warp");
265+
srcLayout = srcLayout.sublayout({kReg, kLane, kWarp},
266+
to_vector(srcLayout.getOutDimNames()));
267+
dstLayout = dstLayout.sublayout({kReg, kLane, kWarp},
268+
to_vector(dstLayout.getOutDimNames()));
269+
270+
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
271+
auto smemBase =
272+
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
273+
auto inVals = unpackLLElements(loc, src, rewriter);
274+
auto outVals = transferWithinBlockSwizzlingImpl(
275+
loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
276+
277+
Value result =
278+
packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy);
279+
rewriter.replaceOp(op, result);
280+
}
281+
135282
// Use warp shuffles to implement a layout conversion where data only needs to
136283
// be moved within warps.
137284
LogicalResult transferWithinWarp(ConvertLayoutOp op, OpAdaptor adaptor,

0 commit comments

Comments
 (0)