Skip to content

Commit e8ef0bb

Browse files
authored
[AMD] Disable swap operands for fp8 matmul (triton-lang#5577)
We found regressions in the MoE kernel with fp8 inputs. This PR effectively reverts part of triton-lang#4767 and disables the swap-operand feature for fp8-input matmul kernels for now, while we investigate the regression.
1 parent 194a21f commit e8ef0bb

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,23 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
337337
: OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion),
338338
nonKDim(nonKDim), kPack(kPack) {}
339339

340+
// Reports whether `dotOp` is chained with another dot operation.
//
// Computes a slice around `dotOp` via getSlice() using the backward and
// forward options configured below, and returns true as soon as that slice
// contains a tt::DotOp other than `dotOp` itself; returns false otherwise.
// NOTE(review): presumably getSlice() yields the producers and consumers
// reachable from dotOp — confirm against the MLIR slice analysis utilities.
bool isChainDot(tt::DotOp &dotOp) const {
341+
// Only accept operations whose parent region is the same as dotOp's.
auto filter = [&dotOp](Operation *op) {
342+
return op->getParentRegion() == dotOp->getParentRegion();
343+
};
344+
ForwardSliceOptions fwdOpt;
345+
fwdOpt.filter = filter;
346+
BackwardSliceOptions bwdOpt;
347+
// The backward options additionally set omitBlockArguments.
bwdOpt.omitBlockArguments = true;
348+
bwdOpt.filter = filter;
349+
auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
350+
// Any other dot op found in the slice marks this dot as part of a chain.
for (Operation *op : slices) {
351+
if (isa<tt::DotOp>(op) && (op != dotOp))
352+
return true;
353+
}
354+
return false;
355+
}
356+
340357
bool isSecondDot(tt::DotOp &dotOp) const {
341358
auto filter = [&dotOp](Operation *op) {
342359
return op->getParentRegion() == dotOp->getParentRegion();
@@ -391,12 +408,16 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
391408
auto warpsPerTile =
392409
warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});
393410

394-
// Always use transposed mfma layout. This enables larger vectorization
395-
// for global store instructions
411+
// Use transposed mfma layout to enable larger vectorization for global
412+
// store instructions, except for non-chained fp8 matmuls due to a regression
413+
// TODO (lixun): investigate the regression and enable this feature again
414+
auto aElemTy = mfmaInstr.getElementTypeA();
415+
bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ();
416+
bool isTransposed = isChainDot(dotOp) || !isFP8;
396417
mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
397418
oldRetType.getContext(),
398419
/*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
399-
/*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout);
420+
/*instrShape*/ mDim, nDim, isTransposed, CTALayout);
400421

401422
Type mfmaAccType;
402423
if (oldRetType.getElementType().isIntOrIndex())

0 commit comments

Comments
 (0)