@@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) {
536536 (elemTy.isInteger (8 ) && version >= 2 );
537537}
538538
539- bool isBlockedToDotShortcut (RankedTensorType srcTy, RankedTensorType dstTy) {
539+ bool isBlockedToDotShortcut (RankedTensorType & srcTy, RankedTensorType & dstTy) {
540540 auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding ());
541541 auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding ());
542542 if (blockedLayout == nullptr || dotOperandLayout == nullptr )
@@ -605,6 +605,22 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
605605 return matrixDimsCompatible && bDimCompatible;
606606}
607607
608+ bool isMfmaToDotShortcut (RankedTensorType srcTy, RankedTensorType dstTy) {
609+ auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding ());
610+ auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding ());
611+ if (mfmaLayout == nullptr || dotOperandLayout == nullptr )
612+ return false ;
613+ // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
614+ // improved. In addition, we can enable this shortcut for regular MFMA
615+ // layout when opIdx == 1.
616+ return mfmaLayout.getWarpsPerCTA ()[1 ] == 1 &&
617+ dotOperandLayout.getOpIdx () == 0 && mfmaLayout.getIsTransposed () &&
618+ dotOperandLayout.getKWidth () == getContigPerThread (mfmaLayout)[1 ] &&
619+ dotOperandLayout.getParent () == mfmaLayout &&
620+ (mfmaLayout.getMDim () == 32 || mfmaLayout.getMDim () == 16 ) &&
621+ (srcTy.getElementType ().isF16 () || srcTy.getElementType ().isBF16 ());
622+ }
623+
608624// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
609625bool matchMmaV3AndDotOperandLayout (RankedTensorType srcTy,
610626 RankedTensorType dstTy) {
@@ -639,46 +655,8 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
639655 toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
640656 if (!(srcLayout.has_value () && dstLayout.has_value ()))
641657 return std::nullopt ;
642- StringAttr kRegister = StringAttr::get (ctx, " register" );
643- StringAttr kLane = StringAttr::get (ctx, " lane" );
644- StringAttr kWarp = StringAttr::get (ctx, " warp" );
645- StringAttr kBlock = StringAttr::get (ctx, " block" );
646- auto numSrcRegs = srcLayout->getInDimSize (kRegister );
647- auto numDstRegs = dstLayout->getInDimSize (kRegister );
648- // The `invertAndCompose` function will generate a layout that is injective
649- // by assigning new output dimensions to free variables. For instance,
650- // consider a scenario where `srcLayout` has a free variable in the lane
651- // dimension, while `dstLayout` has two free variables in the lane
652- // dimension and also a larger number of registers.
653- // The injective form of `srcLayout` will add only a single additional row
654- // to the transformation matrix, whereas the injective form of `dstLayout`
655- // will add two additional rows. This discrepancy causes misleading results
656- // because the matrices end up with a different number of rows.
657- //
658- // Take `dstLayout ⋅ srcLayout^-1` as an example:
659- //
660- // - `injective(dstLayout)`: [n, m] → [n + 2, m]
661- // - `injective(srcLayout)`: [n, m] → [n + 1, m]
662- // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
663- // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
664- // 1] → [n + 2, n + 1]
665- //
666- // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
667- // variable in registers, and the `(n + 2)`-th row represents the free
668- // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
669- // represents the free variable in lanes. As a result, the `(n + 1)`-th row
670- // in two layouts do not correspond to the same free variable.
671- //
672- // To address this issue, we pad the free variables in `srcLayout` and
673- // `dstLayout` to ensure they have the same number of registers. This
674- // guarantees that the resulting matrices have the same number of rows,
675- // ensuring consistency in the composition process.
676- auto numRegs = std::max (numSrcRegs, numDstRegs);
677- auto srcLayoutWithFreeRegs = srcLayout->resize (kRegister , numRegs);
678- auto dstLayoutWithFreeRegs = dstLayout->resize (kRegister , numRegs);
679658 // comp describes the layout function to create dst from src.
680- LinearLayout comp =
681- dstLayoutWithFreeRegs.invertAndCompose (srcLayoutWithFreeRegs);
659+ LinearLayout comp = dstLayout->invertAndCompose (*srcLayout);
682660 // We try to quotient by the largest subspace first
683661 auto dims = SmallVector<StringRef>{" block" , " warp" , " lane" , " register" };
684662 for (auto dim : dims) {
@@ -715,14 +693,15 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
715693}
716694
717695bool cvtNeedsSharedMemory (RankedTensorType srcTy, RankedTensorType dstTy) {
718- // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
719- // `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
720- // checks.
696+ // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
697+ // `isBlockedToDotShortcut` and ` isMfmaToDotShortcut`) once they're fully
698+ // subsumed by the linear-layout checks.
721699 // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
722700 // supported yet in Triton's backend.
723701 return !cvtReordersRegisters (srcTy, dstTy) &&
724702 !isBlockedToDotShortcut (srcTy, dstTy) &&
725- !matchMmaV3AndDotOperandLayout (srcTy, dstTy);
703+ !isMmaToDotShortcut (srcTy, dstTy) &&
704+ !isMfmaToDotShortcut (srcTy, dstTy);
726705}
727706
728707bool atomicNeedsSharedMemory (Value value) {
@@ -732,6 +711,20 @@ bool atomicNeedsSharedMemory(Value value) {
732711 return true ;
733712}
734713
714+ bool isMmaToDotShortcut (RankedTensorType srcTy, RankedTensorType dstTy) {
715+ if (matchMmaV3AndDotOperandLayout (srcTy, dstTy))
716+ return true ;
717+ // dot_op<opIdx=0, parent=#mma> = #mma
718+ // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
719+ auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding ());
720+ auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding ());
721+ return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor () == 2 &&
722+ mmaLayout.getWarpsPerCTA ()[1 ] == 1 &&
723+ dotOperandLayout.getOpIdx () == 0 &&
724+ dotOperandLayout.getParent () == mmaLayout &&
725+ !srcTy.getElementType ().isF32 ();
726+ }
727+
735728namespace {
736729
737730// / A data structure similar to SetVector but maintains
0 commit comments