
Commit d99536d

binarman authored and jataylo committed
[AMD] remove redundant LDS bypass checks (triton-lang#5002)
This commit removes the special cases for MFMA -> dot operand LDS shortcuts; they are now handled by the common linear layout infrastructure. No new tests are added, since mfma-shortcut.mlir already covers this. (cherry picked from commit 69f656c)
1 parent 4499262 · commit d99536d
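To make the claim about the linear layout infrastructure concrete, here is a rough, self-contained sketch of the predicate relationship (plain C++ with mock types; LayoutKind, reordersRegisters, and needsSharedMemory are hypothetical stand-ins, not Triton declarations). Once the generic register-reordering check recognizes the MFMA -> dot operand conversion, an extra !isMfmaToDotShortcut term in the conjunction adds nothing:

// Standalone model, not Triton code: mock layout kinds and predicates.
#include <cassert>

enum class LayoutKind { Blocked, TransposedMfma, DotOperandOfMfma };

// Stand-in for the generic linear-layout check: it already classifies the
// MFMA -> dot operand conversion as a pure register shuffle.
bool reordersRegisters(LayoutKind src, LayoutKind dst) {
  return src == LayoutKind::TransposedMfma && dst == LayoutKind::DotOperandOfMfma;
}

// A conversion needs LDS only when no register-level shortcut applies.
// An extra "&& !isMfmaToDotShortcut(src, dst)" term would be redundant here.
bool needsSharedMemory(LayoutKind src, LayoutKind dst) {
  return !reordersRegisters(src, dst);
}

int main() {
  // The case the deleted special check used to handle is covered generically.
  assert(!needsSharedMemory(LayoutKind::TransposedMfma, LayoutKind::DotOperandOfMfma));
  assert(needsSharedMemory(LayoutKind::Blocked, LayoutKind::DotOperandOfMfma));
  return 0;
}

This mirrors the change to cvtNeedsSharedMemory in lib/Analysis/Utility.cpp below.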

File tree: 5 files changed, +6 -74 lines

include/triton/Analysis/Utility.h
lib/Analysis/Allocation.cpp
lib/Analysis/Utility.cpp
third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp
third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp

include/triton/Analysis/Utility.h
Lines changed: 0 additions & 2 deletions

@@ -214,8 +214,6 @@ bool atomicNeedsSharedMemory(Value result);

 bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

-bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
-
 // Return true if the src and dst layout match.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy);

lib/Analysis/Allocation.cpp
Lines changed: 1 addition & 1 deletion

@@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   Attribute srcLayout = srcTy.getEncoding();
   Attribute dstLayout = dstTy.getEncoding();

-  assert(!isMfmaToDotShortcut(srcTy, dstTy));
+  assert(cvtNeedsSharedMemory(srcTy, dstTy));

   // FIXME This is NOT entirely correct
   // This should be getElemOrder, but we don't have such a method

lib/Analysis/Utility.cpp
Lines changed: 1 addition & 18 deletions

@@ -605,22 +605,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   return matrixDimsCompatible && bDimCompatible;
 }

-bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
-  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
-  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
-  if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
-    return false;
-  // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
-  // improved. In addition, we can enable this shortcut for regular MFMA
-  // layout when opIdx == 1.
-  return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
-         dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
-         dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] &&
-         dotOperandLayout.getParent() == mfmaLayout &&
-         (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
-         (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
-}
-
 // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy) {

@@ -738,8 +722,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !isBlockedToDotShortcut(srcTy, dstTy) &&
-         !matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
-         !isMfmaToDotShortcut(srcTy, dstTy);
+         !matchMmaV3AndDotOperandLayout(srcTy, dstTy);
 }

 bool atomicNeedsSharedMemory(Value value) {

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Lines changed: 0 additions & 51 deletions

@@ -115,64 +115,13 @@ struct LocalLoadOpConversion
   }
 };

-struct ConvertLayoutOpConversion
-    : public ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
-public:
-  using ConvertOpToLLVMPattern<
-      triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value src = op.getSrc();
-    Value dst = op.getResult();
-    auto srcTy = cast<RankedTensorType>(src.getType());
-    auto dstTy = cast<RankedTensorType>(dst.getType());
-    Attribute srcLayout = srcTy.getEncoding();
-    Attribute dstLayout = dstTy.getEncoding();
-
-    if (isa<AMDMfmaEncodingAttr>(srcLayout) &&
-        isa<DotOperandEncodingAttr>(dstLayout)) {
-      return lowerMfmaToDotOperand(op, adaptor, rewriter);
-    }
-    return failure();
-  }
-
-private:
-  LogicalResult
-  lowerMfmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                        ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    RankedTensorType srcTy = op.getSrc().getType();
-    RankedTensorType dstTy = op.getType();
-    if (isMfmaToDotShortcut(srcTy, dstTy)) {
-      // vecSize is an number of sequential elements stored by one thread
-      // - For MFMA encoding (encoding of the result tensor of dot
-      // operation) it is 4
-      // - For MFMA operand encoding it is
-      // dotOperandEncoding::kWidth,
-      // which is 4 in certain cases (e.g. fp16 and bfloat16 dtypes with kpack
-      // = 1)
-      //
-      // For cases where these two values are equal MFMA and MFMA operand
-      // layouts are the same.
-      auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-      Value view =
-          packLLElements(loc, getTypeConverter(), vals, rewriter, dstTy);
-      rewriter.replaceOp(op, view);
-      return success();
-    }
-    return failure();
-  }
-};
 } // namespace

 namespace mlir::triton::AMD {
 void populateConvertLayoutOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
     RewritePatternSet &patterns, int numWarps,
     ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
-  patterns.add<ConvertLayoutOpConversion>(typeConverter, benefit);
   patterns.add<LocalLoadOpConversion>(typeConverter, benefit);
 }
 } // namespace mlir::triton::AMD

third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp
Lines changed: 4 additions & 2 deletions

@@ -38,8 +38,10 @@ struct DecomposeUnsupportedAMDConversions

     triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod);

-    triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod,
-                                                          isMfmaToDotShortcut);
+    auto isShortcut =
+        mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory));
+
+    triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, isShortcut);

     /* -------------------------------- */
     // Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op`