@@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) {
536536 (elemTy.isInteger (8 ) && version >= 2 );
537537}
538538
539- bool isBlockedToDotShortcut (RankedTensorType & srcTy, RankedTensorType & dstTy) {
539+ bool isBlockedToDotShortcut (RankedTensorType srcTy, RankedTensorType dstTy) {
540540 auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding ());
541541 auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding ());
542542 if (blockedLayout == nullptr || dotOperandLayout == nullptr )
@@ -605,22 +605,6 @@ bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
605605 return matrixDimsCompatible && bDimCompatible;
606606}
607607
608- bool isMfmaToDotShortcut (RankedTensorType srcTy, RankedTensorType dstTy) {
609- auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding ());
610- auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding ());
611- if (mfmaLayout == nullptr || dotOperandLayout == nullptr )
612- return false ;
613- // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
614- // improved. In addition, we can enable this shortcut for regular MFMA
615- // layout when opIdx == 1.
616- return mfmaLayout.getWarpsPerCTA ()[1 ] == 1 &&
617- dotOperandLayout.getOpIdx () == 0 && mfmaLayout.getIsTransposed () &&
618- dotOperandLayout.getKWidth () == getContigPerThread (mfmaLayout)[1 ] &&
619- dotOperandLayout.getParent () == mfmaLayout &&
620- (mfmaLayout.getMDim () == 32 || mfmaLayout.getMDim () == 16 ) &&
621- (srcTy.getElementType ().isF16 () || srcTy.getElementType ().isBF16 ());
622- }
623-
624608// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
625609bool matchMmaV3AndDotOperandLayout (RankedTensorType srcTy,
626610 RankedTensorType dstTy) {
@@ -655,8 +639,46 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
655639 toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
656640 if (!(srcLayout.has_value () && dstLayout.has_value ()))
657641 return std::nullopt ;
642+ StringAttr kRegister = StringAttr::get (ctx, " register" );
643+ StringAttr kLane = StringAttr::get (ctx, " lane" );
644+ StringAttr kWarp = StringAttr::get (ctx, " warp" );
645+ StringAttr kBlock = StringAttr::get (ctx, " block" );
646+ auto numSrcRegs = srcLayout->getInDimSize (kRegister );
647+ auto numDstRegs = dstLayout->getInDimSize (kRegister );
648+ // The `invertAndCompose` function will generate a layout that is injective
649+ // by assigning new output dimensions to free variables. For instance,
650+ // consider a scenario where `srcLayout` has a free variable in the lane
651+ // dimension, while `dstLayout` has two free variables in the lane
652+ // dimension and also a larger number of registers.
653+ // The injective form of `srcLayout` will add only a single additional row
654+ // to the transformation matrix, whereas the injective form of `dstLayout`
655+ // will add two additional rows. This discrepancy causes misleading results
656+ // because the matrices end up with a different number of rows.
657+ //
658+ // Take `dstLayout ⋅ srcLayout^-1` as an example:
659+ //
660+ // - `injective(dstLayout)`: [n, m] → [n + 2, m]
661+ // - `injective(srcLayout)`: [n, m] → [n + 1, m]
662+ // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
663+ // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
664+ // 1] → [n + 2, n + 1]
665+ //
666+ // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
667+ // variable in registers, and the `(n + 2)`-th row represents the free
668+ // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
669+ // represents the free variable in lanes. As a result, the `(n + 1)`-th rows
670+ // of the two layouts do not correspond to the same free variable.
671+ //
672+ // To address this issue, we pad the free variables in `srcLayout` and
673+ // `dstLayout` to ensure they have the same number of registers. This
674+ // guarantees that the resulting matrices have the same number of rows,
675+ // ensuring consistency in the composition process.
676+ auto numRegs = std::max (numSrcRegs, numDstRegs);
677+ auto srcLayoutWithFreeRegs = srcLayout->resize (kRegister , numRegs);
678+ auto dstLayoutWithFreeRegs = dstLayout->resize (kRegister , numRegs);
658679 // comp describes the layout function to create dst from src.
659- LinearLayout comp = dstLayout->invertAndCompose (*srcLayout);
680+ LinearLayout comp =
681+ dstLayoutWithFreeRegs.invertAndCompose (srcLayoutWithFreeRegs);
660682 // We try to quotient by the largest subspace first
661683 auto dims = SmallVector<StringRef>{" block" , " warp" , " lane" , " register" };
662684 for (auto dim : dims) {
@@ -693,15 +715,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
693715}
694716
695717bool cvtNeedsSharedMemory (RankedTensorType srcTy, RankedTensorType dstTy) {
696- // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
697- // `isBlockedToDotShortcut` and ` isMfmaToDotShortcut`) once they're fully
698- // subsumed by the linear-layout checks.
718+ // TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
719+ // `matchMmaV3AndDotOperandLayout`) once they're fully subsumed by the
720+ // linear-layout checks.
699721 // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
700722 // supported yet in Triton's backend.
701723 return !cvtReordersRegisters (srcTy, dstTy) &&
702724 !isBlockedToDotShortcut (srcTy, dstTy) &&
703- !isMmaToDotShortcut (srcTy, dstTy) &&
704- !isMfmaToDotShortcut (srcTy, dstTy);
725+ !matchMmaV3AndDotOperandLayout (srcTy, dstTy);
705726}
706727
707728bool atomicNeedsSharedMemory (Value value) {
@@ -711,20 +732,6 @@ bool atomicNeedsSharedMemory(Value value) {
711732 return true ;
712733}
713734
714- bool isMmaToDotShortcut (RankedTensorType srcTy, RankedTensorType dstTy) {
715- if (matchMmaV3AndDotOperandLayout (srcTy, dstTy))
716- return true ;
717- // dot_op<opIdx=0, parent=#mma> = #mma
718- // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
719- auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding ());
720- auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding ());
721- return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor () == 2 &&
722- mmaLayout.getWarpsPerCTA ()[1 ] == 1 &&
723- dotOperandLayout.getOpIdx () == 0 &&
724- dotOperandLayout.getParent () == mmaLayout &&
725- !srcTy.getElementType ().isF32 ();
726- }
727-
728735namespace {
729736
730737// / A data structure similar to SetVector but maintains
0 commit comments