Commit ef210bc

Revert "Revert "[BACKEND] Localize the use and definition of getShapePerCTATile in the AMD backend and aim for elimination (#7740)"" (#4992)
This reverts commit 59ac5e8.

Signed-off-by: Anatoly Myachev <[email protected]>

1 parent: a57ac88 · commit: ef210bc

File tree

17 files changed (+457, -661 lines)


include/triton/Analysis/Allocation.h

Lines changed: 0 additions & 45 deletions
@@ -20,54 +20,9 @@ using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
 
 unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
 
-// To convert a tensor from one layout to another, we need to allocate a
-// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
-// require multiple iterations, with each iteration involving multiple
-// vectorized loads/stores. The scratch buffer has a shape (`repShape`) that
-// represents the maximum size accessed in each dimension during each iteration.
-// It is padded (`paddedRepShape`) to avoid bank conflicts and is accessed in a
-// specific `order`.
-struct ScratchConfig {
-  SmallVector<unsigned> repShape;
-  SmallVector<unsigned> paddedRepShape;
-  SmallVector<unsigned> order;
-  unsigned inVec;
-  unsigned outVec;
-
-  ScratchConfig(SmallVector<unsigned> repShape,
-                SmallVector<unsigned> paddedRepShape, unsigned inVec = 1,
-                unsigned outVec = 1)
-      : repShape(repShape), paddedRepShape(paddedRepShape), inVec(inVec),
-        outVec(outVec) {}
-
-  void print(llvm::raw_ostream &os) const {
-    os << "repShape: [";
-    llvm::interleaveComma(repShape, os);
-    os << "]";
-    os << ", paddedRepShape: [";
-    llvm::interleaveComma(paddedRepShape, os);
-    os << "]";
-    os << ", order: [";
-    llvm::interleaveComma(order, os);
-    os << "]";
-    os << ", inVec: " << inVec << ", outVec: " << outVec << "\n";
-  }
-};
-
-// For a layout conversion between `srcTy` and `dstTy`, return the vector length
-// that can be used for the stores to and loads from shared memory,
-// respectively.
-std::pair</*inVec*/ unsigned, /*outVec*/ unsigned>
-getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy);
-
-ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
-                                     RankedTensorType dstTy);
-
 unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
                                        RankedTensorType dstTy);
 
-unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
-                                     RankedTensorType dstTy);
 } // namespace triton
 
 /// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
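
The removed header comment describes the padded scratch-buffer scheme: a layout conversion goes through shared memory in one or more iterations, `repShape` is the per-iteration extent in each dimension, and `paddedRepShape` adds padding to avoid bank conflicts. Below is a minimal sketch of that bookkeeping in plain C++, with no MLIR dependencies; the struct, field names, and numbers are illustrative only, mirroring the removed ScratchConfig rather than reproducing it.

// A minimal sketch, assuming the padded-scratch scheme described above.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

struct ScratchConfigSketch {
  std::vector<unsigned> repShape;       // max extent touched per dim, per iteration
  std::vector<unsigned> paddedRepShape; // repShape plus padding on the inner dim
  unsigned inVec = 1;                   // vector width of stores into shared memory
  unsigned outVec = 1;                  // vector width of loads out of shared memory
};

// Scratch element count is the product of the (padded) shape.
unsigned numScratchElems(const std::vector<unsigned> &shape) {
  return std::accumulate(shape.begin(), shape.end(), 1u,
                         std::multiplies<unsigned>());
}

int main() {
  ScratchConfigSketch cfg{{64, 64}, {64, 64}, 4, 4};
  // Pad the fastest-accessed dimension (assumed last here) by the larger
  // vector width, in the spirit of the removed getScratchConfigForCvt.
  cfg.paddedRepShape.back() += std::max(cfg.inVec, cfg.outVec);
  std::printf("unpadded: %u elements, padded: %u elements\n",
              numScratchElems(cfg.repShape),
              numScratchElems(cfg.paddedRepShape));
  return 0;
}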

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 20 deletions
@@ -627,26 +627,6 @@ void makeAllWarpGroupsIsolatedFromAbove(Operation *op);
 // Set the correct loop annotation on LLVM branch ops.
 void fixUpLoopAnnotation(ModuleOp mod);
 
-/// Converts ConverLayoutOp to llvm using padded pattern.
-/// This pattern adds unused memory locations after every rows of tensor fastest
-/// changing dimension:
-/// e0 e1 e2 e3 p p \
-/// e4 e5 e6 e7 p p \
-/// ...
-/// e e e e p p
-/// Dimension order is chosen in order to use wide output reads.
-///
-/// \param op operation to convert
-/// \param src llvm structure containing operation input
-/// \param targetInfo
-/// \param typeConverter
-/// \param rewriter
-/// \returns llvm structure containing converted output
-Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src,
-                                 const TargetInfoBase &targetInfo,
-                                 const LLVMTypeConverter *typeConverter,
-                                 RewriterBase &rewriter);
-
 LogicalResult
 transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src,
                              const TargetInfoBase &targetInfo,
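
The removed doc comment describes the padded conversion pattern: every row of the tensor's fastest-changing dimension is followed by unused padding slots, so the padded row stride is the row length plus the padding. A minimal sketch of the resulting addressing follows, in plain C++ with illustrative names and values; it is not the actual transferWithinBlockPadding implementation.

// A minimal sketch of padded row-major addressing, assuming the scheme above.
#include <cstdio>

// Offset of logical element (row, col) in a row-major buffer padded by `pad`.
unsigned paddedOffset(unsigned row, unsigned col, unsigned rowLen, unsigned pad) {
  return row * (rowLen + pad) + col;
}

int main() {
  const unsigned rowLen = 4, pad = 2;
  // Reproduces the layout from the removed comment:
  //   e0 e1 e2 e3 p p
  //   e4 e5 e6 e7 p p
  for (unsigned r = 0; r < 2; ++r)
    for (unsigned c = 0; c < rowLen; ++c)
      std::printf("e%u -> offset %u\n", r * rowLen + c,
                  paddedOffset(r, c, rowLen, pad));
  return 0;
}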

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 0 additions & 10 deletions
@@ -210,16 +210,6 @@ SmallVector<unsigned> getCTASplitNum(Attribute layout);
 
 SmallVector<unsigned> getCTAOrder(Attribute layout);
 
-/* The difference between ShapePerCTATile and ShapePerCTA:
- * (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp *
- * WarpsPerCTA in each dimension and is independent from the tensor shape.
- * (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension.
- * (3) In the implementation of emitIndices, ShapePerCTATile will
- * be replicated or wrapped to fit ShapePerCTA.
- */
-// [FIXME LL] Kill this function
-SmallVector<unsigned> getShapePerCTATile(RankedTensorType layout);
-
 // Returns the "logical" shape per CTA.
 // When shape and CTASplitNum have different number of dimensions, we assume
 // only the last N between common dimensions are split.
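
The removed comment contrasts ShapePerCTATile (SizePerThread * ThreadsPerWarp * WarpsPerCTA per dimension, independent of the tensor shape) with ShapePerCTA (shape / CTASplitNum per dimension). A minimal sketch of both computations with plain vectors follows; the example numbers are illustrative and these are not the MLIR helpers being removed.

// A minimal sketch of the two per-dimension shape computations described above.
#include <cstdio>
#include <vector>

// ShapePerCTATile: sizePerThread * threadsPerWarp * warpsPerCTA per dimension.
std::vector<unsigned> shapePerCTATile(const std::vector<unsigned> &sizePerThread,
                                      const std::vector<unsigned> &threadsPerWarp,
                                      const std::vector<unsigned> &warpsPerCTA) {
  std::vector<unsigned> out(sizePerThread.size());
  for (size_t d = 0; d < out.size(); ++d)
    out[d] = sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d];
  return out;
}

// ShapePerCTA: the tensor shape divided by CTASplitNum per dimension.
std::vector<unsigned> shapePerCTA(const std::vector<unsigned> &shape,
                                  const std::vector<unsigned> &ctaSplitNum) {
  std::vector<unsigned> out(shape.size());
  for (size_t d = 0; d < out.size(); ++d)
    out[d] = shape[d] / ctaSplitNum[d];
  return out;
}

int main() {
  auto tile = shapePerCTATile({1, 4}, {8, 4}, {2, 2}); // {16, 32}
  auto cta = shapePerCTA({128, 64}, {1, 1});           // {128, 64}
  std::printf("ShapePerCTATile: %ux%u, ShapePerCTA: %ux%u\n", tile[0], tile[1],
              cta[0], cta[1]);
  return 0;
}

When the tile is smaller than the per-CTA shape, it is replicated (or wrapped) to cover it, which is the emitIndices behaviour the removed comment refers to.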

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 0 additions & 1 deletion
@@ -751,7 +751,6 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
   SmallVector<unsigned> getWarpsPerCTA() const;
 
   // [FIXME LL] Supports legacy behaviour. We should remove these functions
-  SmallVector<unsigned> getShapePerCTATile() const;
   SmallVector<unsigned> getSizePerThread() const;
 }];

lib/Analysis/Allocation.cpp

Lines changed: 1 addition & 118 deletions
@@ -29,15 +29,6 @@ namespace mlir {
 //===----------------------------------------------------------------------===//
 namespace triton {
 
-// Max shmem LDS/STS instruction in bits
-constexpr int kMaxShmemVecBitLength = 128;
-
-unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
-                                     RankedTensorType dstTy) {
-  auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
-  return getNumScratchElements(scratchConfig.paddedRepShape);
-}
-
 unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
                                        RankedTensorType dstTy) {
   auto *ctx = srcTy.getContext();
@@ -51,40 +42,6 @@ unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
   return smem.getTotalOutDimSize() / reps;
 }
 
-static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
-                                               RankedTensorType dstTy) {
-  Attribute srcLayout = srcTy.getEncoding();
-  Attribute dstLayout = dstTy.getEncoding();
-
-  if (!cvtNeedsSharedMemory(srcTy, dstTy)) {
-    return {};
-  }
-
-  if (shouldUseDistSmem(srcLayout, dstLayout)) {
-    // TODO: padding to avoid bank conflicts
-    return convertType<unsigned, int64_t>(gpu::getShapePerCTA(srcTy));
-  }
-
-  assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()");
-
-  auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
-  auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
-  auto srcShapePerCTATile = gpu::getShapePerCTATile(srcTy);
-  auto dstShapePerCTATile = gpu::getShapePerCTATile(dstTy);
-
-  assert(srcTy.getRank() == dstTy.getRank() &&
-         "src and dst must have the same rank");
-
-  unsigned rank = dstTy.getRank();
-  SmallVector<unsigned> repShape(rank);
-  for (unsigned d = 0; d < rank; ++d) {
-    repShape[d] =
-        std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
-                 std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
-  }
-  return repShape;
-}
-
 // Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
 // because Triton's block-based programming model ensures that
 // all threads sharing the same partition of the tensor see the same values,
@@ -99,7 +56,7 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
             return variableMask.second != 0;
           })) {
         // The tensor has broadcasted dimensions
-        smemShape = gpu::getShapePerCTATile(tensorTy);
+        smemShape = convertType<unsigned>(gpu::getShapePerCTA(tensorTy));
       }
   } else {
     // If the result is a scalar, we need to allocate a single element.
@@ -109,80 +66,6 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
   return smemShape;
 }
 
-std::pair<unsigned, unsigned>
-getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) {
-  Attribute srcLayout = srcTy.getEncoding();
-  Attribute dstLayout = dstTy.getEncoding();
-
-  auto srcLinAttr = gpu::toLinearEncoding(srcTy);
-  auto dstLinAttr = gpu::toLinearEncoding(dstTy);
-  auto inOrd = srcLinAttr.getOrder();
-  auto outOrd = dstLinAttr.getOrder();
-
-  unsigned rank = srcTy.getRank();
-
-  unsigned srcContigPerThread = srcLinAttr.getContigPerThread()[inOrd[0]];
-  unsigned dstContigPerThread = dstLinAttr.getContigPerThread()[outOrd[0]];
-  // TODO: Fix the legacy issue that outOrd[0] == 0 always means
-  // that we cannot do vectorization.
-  unsigned innerDim = rank - 1;
-  unsigned inVec = outOrd[0] != innerDim  ? 1
-                   : inOrd[0] != innerDim ? 1
-                                          : srcContigPerThread;
-  unsigned outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread;
-
-  if (isa<gpu::NvidiaMmaEncodingAttr>(srcLayout) &&
-      isa<gpu::BlockedEncodingAttr>(dstLayout)) {
-    // when storing from mma layout and loading in blocked layout vectorizing
-    // the load back gives better performance even if there is a
-    // transposition.
-    outVec = dstContigPerThread;
-  }
-  return {inVec, outVec};
-}
-
-ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
-                                     RankedTensorType dstTy) {
-  // Initialize vector sizes and stride
-  auto repShape = getRepShapeForCvt(srcTy, dstTy);
-  if (repShape.empty())
-    return ScratchConfig({}, {});
-  ScratchConfig scratchConfig(repShape, repShape);
-  auto rank = repShape.size();
-  Attribute srcLayout = srcTy.getEncoding();
-  Attribute dstLayout = dstTy.getEncoding();
-
-  assert(cvtNeedsSharedMemory(srcTy, dstTy));
-  auto outOrd = gpu::getOrder(dstTy);
-  scratchConfig.order = outOrd;
-
-  std::tie(scratchConfig.inVec, scratchConfig.outVec) =
-      getScratchCvtInOutVecLengths(srcTy, dstTy);
-  // We can't write a longer vector than the shape of shared memory.
-  // This shape might be smaller than the tensor shape in case we decided to
-  // do the conversion in multiple iterations.
-  unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
-  scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
-  scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
-  // Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
-  // is the max vectorisation
-  auto inBitWidth = getBitwidth(srcTy);
-  auto outBitWidth = getBitwidth(dstTy);
-  scratchConfig.inVec =
-      std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth);
-  scratchConfig.outVec =
-      std::min(scratchConfig.outVec, kMaxShmemVecBitLength / outBitWidth);
-
-  // No padding is required if the tensor is 1-D, or if all dimensions except
-  // the first accessed dimension have a size of 1.
-  if (rank <= 1 || product(repShape) == repShape[outOrd[0]])
-    return scratchConfig;
-
-  auto paddedSize = std::max(scratchConfig.inVec, scratchConfig.outVec);
-  scratchConfig.paddedRepShape[outOrd[0]] += paddedSize;
-  return scratchConfig;
-}
-
 unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
   if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
     ReduceOpHelper helper(reduceOp);
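
The removed getScratchConfigForCvt clamps the shared-memory vector widths twice: first to the contiguous extent of the scratch buffer, then to the 128-bit LDS/STS limit (kMaxShmemVecBitLength) divided by the element bitwidth. A minimal sketch of that clamping in plain C++ follows; names and values are illustrative, not the Triton implementation.

// A minimal sketch of the two-step vector-width clamping described above.
#include <algorithm>
#include <cstdio>

unsigned clampVec(unsigned vec, unsigned contiguousShapeDim,
                  unsigned elemBitWidth, unsigned maxVecBits = 128) {
  vec = std::min(vec, contiguousShapeDim);        // cannot exceed the smem row
  vec = std::min(vec, maxVecBits / elemBitWidth); // cannot exceed 128 bits
  return vec;
}

int main() {
  // 8 contiguous fp32 elements per thread clamp to 128 / 32 = 4.
  std::printf("clamped vec = %u\n", clampVec(8, 64, 32));
  return 0;
}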

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 5 additions & 6 deletions
@@ -22,15 +22,15 @@ using namespace mlir;
 using namespace mlir::triton::gpu;
 
 constexpr int kPtrBitWidth = 64;
-struct ConvertLayoutOpUsingLinearLayoutsConversion
+struct ConvertLayoutOpConversion
     : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
   const TargetInfoBase &targetInfo;
 
   // Set benefit to 2 so that this pattern applies before other convert-layout
   // conversions. TODO(jlebar): Eventually we want this to be the only pattern.
-  explicit ConvertLayoutOpUsingLinearLayoutsConversion(
-      LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
-      PatternBenefit benefit = 1)
+  explicit ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter,
+                                     const TargetInfoBase &targetInfo,
+                                     PatternBenefit benefit = 1)
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
   }
 
@@ -590,6 +590,5 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ConvertLayoutOpUsingLinearLayoutsConversion>(
-      typeConverter, targetInfo, benefit);
+  patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
 }
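
The retained comment mentions PatternBenefit: when several rewrite patterns match the same op, the one registered with the higher benefit is tried first. A minimal plain-C++ sketch of that ordering follows; the struct and pattern names are illustrative and this is not the MLIR API.

// A minimal sketch of benefit-ordered pattern selection, assuming the
// higher-benefit-first rule described above.
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

struct PatternSketch {
  std::string name;
  unsigned benefit;
};

int main() {
  std::vector<PatternSketch> patterns = {
      {"ConvertLayoutOpConversion", 2},
      {"SomeFallbackConversion", 1},
  };
  // Try patterns in descending benefit order.
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const PatternSketch &a, const PatternSketch &b) {
                     return a.benefit > b.benefit;
                   });
  for (const auto &p : patterns)
    std::printf("try %s (benefit %u)\n", p.name.c_str(), p.benefit);
  return 0;
}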

Comments (0)