Skip to content

Commit 58b35fe

Browse files
Revert "[LAYOUTS] Implement generalized swizzling for convert_layout (#7565)"
This reverts commit b3b9931.
1 parent 661f264 commit 58b35fe

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+650
-1358
lines changed

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
8989
mlir::triton::registerConvertWarpSpecializeToLLVM();
9090
mlir::triton::registerConvertTritonGPUToLLVMPass();
9191
mlir::triton::registerConvertNVGPUToLLVMPass();
92-
mlir::triton::registerAllocateSharedMemoryNvPass();
9392
mlir::registerLLVMDIScope();
9493
mlir::triton::gpu::intel::registerTritonAnnotateModulePass();
9594
mlir::triton::gpu::intel::registerTritonIntelGPUPasses();

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ class TargetInfoBase {
3838
pred);
3939
}
4040

41+
virtual bool canUseStMatrix(RankedTensorType tensorTy,
42+
ArrayRef<unsigned> repShape,
43+
ArrayRef<unsigned> paddedRepShape,
44+
ArrayRef<unsigned> order,
45+
int swizzleByteSize) const = 0;
46+
47+
virtual void storeMatrixShared(RewriterBase &rewriter, Location loc,
48+
Value ptr, Value val) const = 0;
49+
4150
virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val,
4251
int i) const = 0;
4352
virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1111
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
1212
#include "triton/Dialect/TritonGPU/IR/Types.h"
13-
#include "triton/Tools/GenericSwizzling.h"
1413
#include "triton/Tools/LinearLayout.h"
1514
#include "triton/Tools/StrUtil.h"
1615
#include "llvm/ADT/STLExtras.h"
@@ -322,10 +321,6 @@ namespace mlir {
322321
namespace triton {
323322

324323
namespace gpu {
325-
326-
std::pair<SmallVector<LocalMemOpTile>, SmallVector<LocalMemOpTile>>
327-
getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth);
328-
329324
Type getFunctionType(Type resultType, ValueRange operands);
330325

331326
LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
@@ -612,6 +607,10 @@ std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp);
612607

613608
std::optional<LLVM::AtomicOrdering> getMemoryOrdering(MemSemantic memOrdering);
614609

610+
bool isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
611+
ArrayRef<int64_t> allocShape,
612+
triton::gpu::SharedEncodingTrait sharedEnc);
613+
615614
llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx);
616615

617616
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type);

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77

88
namespace mlir {
99

10-
// Bitwidth of pointers
11-
constexpr int kPtrBitWidth = 64;
12-
1310
template <typename T, typename U> SmallVector<T> convertType(ArrayRef<U> in) {
1411
SmallVector<T> out;
1512
for (const auto &i : in)
@@ -189,7 +186,6 @@ bool isHostSideDescriptor(Value v);
189186

190187
bool isKernel(FunctionOpInterface funcOp);
191188

192-
unsigned getBitwidth(RankedTensorType ty);
193189
} // namespace triton
194190
} // namespace mlir
195191

include/triton/Tools/GenericSwizzling.h

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,49 +4,17 @@
44
#include "llvm/ADT/ArrayRef.h"
55
#include "llvm/ADT/SmallVector.h"
66
#include <cstdint>
7-
#include <utility>
87

98
namespace mlir::triton {
109
class LinearLayout;
11-
class TargetInfoBase;
12-
} // namespace mlir::triton
10+
}
1311

1412
namespace mlir::triton::gpu {
15-
// Store the lane indices that are used in the contiguous part
16-
// of an operation and in the address part.
17-
// The laneAddr part just represents the indices used in one wavefront
18-
// For now we just represent tiles with full vectorisation, meaning
19-
// ld.shared.b32.v4/st.shared.b32.v4
20-
// ldmatrix.v4 / stmatrix.v4
21-
// ldmatrix.trans.v4 / stmatrix.trans.v4
22-
struct LocalMemOpTile {
23-
// If laneContig.size() < log2(128/bitwidth), we assume that
24-
// the first log2(128/bitwidth) - laneContig.size() bases are registers
25-
llvm::SmallVector<int32_t> laneContig;
26-
// If laneAddr.size() < 3, we assume that the first
27-
// 3 - laneAddr.size() bases are registers
28-
llvm::SmallVector<int32_t> laneAddr;
29-
};
13+
LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
14+
int32_t bitwidth);
3015

31-
// Given a set of possible instructions given by
32-
// targetInfo.laneIdTiles(bitwidth) returns the optimal swizzling given these
33-
// instructions and a pair of indices into the ldStTiles that's needed to lower
34-
// this swizzling
35-
std::pair<LinearLayout, std::pair<int32_t, int32_t>>
36-
optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
37-
llvm::ArrayRef<LocalMemOpTile> srcTiles,
38-
llvm::ArrayRef<LocalMemOpTile> dstTiles, int32_t bitwidth);
39-
40-
LinearLayout optimalSwizzlingLdSt(const LinearLayout &src,
41-
const LinearLayout &dst, int32_t bitwidth);
42-
43-
std::pair<int, int> logBankConflictsLdSt(const LinearLayout &src,
44-
const LinearLayout &dst,
45-
const LinearLayout &smem,
46-
int32_t bitwidth);
47-
48-
std::pair<int, int> logBankConflicts(llvm::ArrayRef<int32_t> tileSrc,
49-
llvm::ArrayRef<int32_t> tileDst,
16+
std::pair<int, int> logBankConflicts(const LinearLayout &src,
17+
const LinearLayout &dst,
5018
const LinearLayout &smem,
5119
int32_t bitwidth);
5220
} // namespace mlir::triton::gpu

include/triton/Tools/LayoutUtils.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,7 @@ std::optional<ColumnAction> regPermForDivide(const LinearLayout &A,
126126
ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);
127127

128128
std::pair<int64_t, ColumnAction>
129-
actionAdditiveStrides(const LinearLayout &layout, const LinearLayout addrLayout,
130-
uint64_t maskSpanOffsets);
129+
actionAdditiveStrides(const LinearLayout &layout, uint64_t maskSpanOffsets);
131130

132131
// For a layout A with A.hasInDim(kReg), repeat the values so that they have
133132
// the same broadcasting as layout

lib/Analysis/Allocation.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,14 @@ namespace mlir {
2929
//===----------------------------------------------------------------------===//
3030
namespace triton {
3131

32+
// Bitwidth of pointers
33+
constexpr int kPtrBitWidth = 64;
3234
// Max shmem LDS/STS instruction in bits
3335
constexpr int kMaxShmemVecBitLength = 128;
3436

35-
unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
36-
RankedTensorType dstTy) {
37-
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
38-
return getNumScratchElements(scratchConfig.paddedRepShape);
37+
static unsigned getBitwidth(RankedTensorType ty) {
38+
auto isPtr = isa<PointerType>(ty.getElementType());
39+
return isPtr ? kPtrBitWidth : std::max(ty.getElementTypeBitWidth(), 8u);
3940
}
4041

4142
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
@@ -46,11 +47,17 @@ unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
4647
srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
4748
dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
4849
auto bitwidth = getBitwidth(srcTy);
49-
auto smem = gpu::optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
50+
auto smem = gpu::optimalSwizzling(srcLayout, dstLayout, bitwidth);
5051
auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
5152
return smem.getTotalOutDimSize() / reps;
5253
}
5354

55+
unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
56+
RankedTensorType dstTy) {
57+
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
58+
return getNumScratchElements(scratchConfig.paddedRepShape);
59+
}
60+
5461
static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
5562
RankedTensorType dstTy) {
5663
Attribute srcLayout = srcTy.getEncoding();
@@ -208,8 +215,10 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
208215
auto dstTy = cvtLayout.getType();
209216
if (!cvtNeedsSharedMemory(srcTy, dstTy))
210217
return 0;
211-
// The generic pass uses swizzling
212-
auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
218+
// Pessimistically take the max. We will revisit later
219+
auto elems = std::max(getNumScratchElemsSwizzledCvt(srcTy, dstTy),
220+
getNumScratchElemsPaddedCvt(srcTy, dstTy));
221+
213222
return elems * getBitwidth(srcTy) / 8;
214223
}
215224
if (isa<AtomicRMWOp, AtomicCASOp>(op)) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
6363
} else if (llvm::is_contained(dims, kWarp)) {
6464
// Case 2: Transfer between values in the same CTA, in which case we move
6565
// values through shared memory.
66-
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
67-
return success();
66+
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
6867
} else if (llvm::is_contained(dims, kLane)) {
6968
// Case 3. Transfer between values in the same warp, in which case we try
7069
// to move values using warp shuffles, though if the pattern is
@@ -75,8 +74,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
7574
// TODO: Since data is only transferred within a warp over shared memory,
7675
// we should use `bar.warp.sync` instead of `barrier`, which will improve
7776
// latency when warps issue barriers on different cycles.
78-
transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
79-
return success();
77+
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
8078
} else if (llvm::is_contained(dims, kRegister)) {
8179
// Case 4. Transfer between values in the same thread, in which case we
8280
// simply reorder the elements of adaptor.getSrc().
@@ -171,7 +169,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
171169
// At this point we have a type that's at least 8-bit
172170
// and we don't have broadcasting in the registers
173171
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
174-
auto smem = optimalSwizzlingLdSt(srcLayout, dstLayout, bitwidth);
172+
auto smem = optimalSwizzling(srcLayout, dstLayout, bitwidth);
175173

176174
// Extract reps from smem
177175
auto kReg = str_attr("register");
@@ -203,9 +201,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
203201

204202
assert(permutedInVals.size() == tileSize * nReps);
205203
SmallVector<Value> outVals;
204+
auto noPaddingOffset = [](Value v) { return v; };
206205
auto affineOffset = b.i32_val(0);
207206
auto maskSpanAffineOffset = 0;
208-
auto noPaddingOffset = [](Value v) { return v; };
209207
for (int i = 0; i < nReps; ++i) {
210208
if (i > 0)
211209
b.barrier();
@@ -229,8 +227,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
229227
return outVals;
230228
}
231229

232-
void transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
233-
ConversionPatternRewriter &rewriter) const {
230+
LogicalResult
231+
transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
232+
ConversionPatternRewriter &rewriter) const {
233+
// Fallback for now to standard lowering if it can use stmatrix
234+
auto scratchConfig =
235+
getScratchConfigForCvt(op.getSrc().getType(), op.getType());
236+
bool isStMatrix = targetInfo.canUseStMatrix(
237+
op.getSrc().getType(), scratchConfig.repShape,
238+
scratchConfig.paddedRepShape, scratchConfig.order,
239+
/*swizzleByteSize=*/0);
240+
if (isStMatrix) {
241+
return failure();
242+
}
243+
234244
auto loc = op.getLoc();
235245
auto *ctx = op.getContext();
236246
auto srcTy = op.getSrc().getType();
@@ -258,6 +268,28 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
258268
Value result =
259269
packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy);
260270
rewriter.replaceOp(op, result);
271+
return success();
272+
}
273+
274+
LogicalResult transferWithinBlock(ConvertLayoutOp op,
275+
const LinearLayout &srcLayout,
276+
const LinearLayout &dstLayout,
277+
OpAdaptor adaptor,
278+
ConversionPatternRewriter &rewriter) const {
279+
assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
280+
281+
// Try to use swizzling to implement the conversion
282+
// HACK Remove once AMD tests pass for the swizzling path
283+
if (targetInfo.isCuda() && succeeded(transferWithinBlockSwizzling(
284+
op, adaptor.getSrc(), rewriter))) {
285+
return success();
286+
}
287+
288+
Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
289+
getTypeConverter(), rewriter);
290+
291+
rewriter.replaceOp(op, result);
292+
return success();
261293
}
262294

263295
// Use warp shuffles to implement a layout conversion where data only needs to

0 commit comments

Comments (0)