@@ -337,6 +337,23 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
       : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion),
         nonKDim(nonKDim), kPack(kPack) {}
 
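+  // Check whether this tt::DotOp is part of a chain of dots, i.e. whether
+  // another tt::DotOp appears in its backward or forward slice within the
+  // same region (e.g. the two back-to-back dots of a flash-attention kernel).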
+  bool isChainDot(tt::DotOp &dotOp) const {
+    auto filter = [&dotOp](Operation *op) {
+      return op->getParentRegion() == dotOp->getParentRegion();
+    };
+    ForwardSliceOptions fwdOpt;
+    fwdOpt.filter = filter;
+    BackwardSliceOptions bwdOpt;
+    bwdOpt.omitBlockArguments = true;
+    bwdOpt.filter = filter;
+    auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
+    for (Operation *op : slices) {
+      if (isa<tt::DotOp>(op) && (op != dotOp))
+        return true;
+    }
+    return false;
+  }
+
   bool isSecondDot(tt::DotOp &dotOp) const {
     auto filter = [&dotOp](Operation *op) {
       return op->getParentRegion() == dotOp->getParentRegion();
@@ -391,12 +408,16 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     auto warpsPerTile =
         warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});
 
-    // Always use transposed mfma layout. This enables larger vectorization
-    // for global store instructions
+    // Use transposed mfma layout to enable larger vectorization for global
+    // store instructions, except for fp8 matmul kernels due to a regression.
+    // TODO(lixun): investigate the regression and enable this feature again
+    auto aElemTy = mfmaInstr.getElementTypeA();
+    bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ();
+    bool isTransposed = isChainDot(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),
         /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
-        /*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout);
+        /*instrShape*/ mDim, nDim, isTransposed, CTALayout);
 
     Type mfmaAccType;
     if (oldRetType.getElementType().isIntOrIndex())