@@ -543,7 +543,7 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }

-bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
   auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
   if (blockedLayout == nullptr || dotOperandLayout == nullptr)
@@ -647,57 +647,94 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
   return ans;
 }

-bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
+// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
+// under kBlock, kWarp or kLane (in that order). The idea here is that if we
+// have a transformation that's the identity on kBlock, we don't need to use
+// distributed shared memory. If it's also the identity on kWarp, we can
+// transfer via warp-shuffles, and if it's the identity on kLane, we just
+// have to reorder the registers.
+std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
+                                             RankedTensorType dstTy) {
   MLIRContext *ctx = srcTy.getContext();
   std::optional<LinearLayout> srcLayout =
       toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
   std::optional<LinearLayout> dstLayout =
       toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
-  if (srcLayout.has_value() && dstLayout.has_value()) {
-    // comp describes the layout function for converting from src to dst.
-    LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
-    StringAttr kLane = StringAttr::get(ctx, "lane");
-    StringAttr kWarp = StringAttr::get(ctx, "warp");
-    StringAttr kBlock = StringAttr::get(ctx, "block");
-    // TODO(jlebar): These checks are overly-restrictive. For example, we can
-    // transfer by shuffling registers (case 1) if and only if all of the bases
-    // for `register` have 0s for lane, warp, and block. But the check below is
-    // stronger than this, checking also that the choice of lane/warp/block does
-    // not affect the permutation of registers. If we allow different
-    // lane/warp/blocks to have different permutations, we can generalize this.
-    if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane),
-                                                  kLane, kLane) *
-                         LinearLayout::identity1D(comp.getInDimSize(kWarp),
-                                                  kWarp, kWarp) *
-                         LinearLayout::identity1D(comp.getInDimSize(kBlock),
-                                                  kBlock, kBlock))
-            .has_value()) {
-      return true;
+  if (!(srcLayout.has_value() && dstLayout.has_value()))
+    return std::nullopt;
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+  StringAttr kBlock = StringAttr::get(ctx, "block");
+  auto numSrcRegs = srcLayout->getInDimSize(kRegister);
+  auto numDstRegs = dstLayout->getInDimSize(kRegister);
+  // The `invertAndCompose` function will generate a layout that is injective
+  // by assigning new output dimensions to free variables. For instance,
+  // consider a scenario where `srcLayout` has a free variable in the lane
+  // dimension, while `dstLayout` has two free variables in the lane
+  // dimension and also a larger number of registers.
+  // The injective form of `srcLayout` will add only a single additional row
+  // to the transformation matrix, whereas the injective form of `dstLayout`
+  // will add two additional rows. This discrepancy causes misleading results
+  // because the matrices end up with a different number of rows.
+  //
+  // Take `dstLayout ⋅ srcLayout^-1` as an example:
+  //
+  // - `injective(dstLayout)`: [n, m] → [n + 2, m]
+  // - `injective(srcLayout)`: [n, m] → [n + 1, m]
+  // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
+  // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
+  // 1] → [n + 2, n + 1]
+  //
+  // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
+  // variable in registers, and the `(n + 2)`-th row represents the free
+  // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
+  // represents the free variable in lanes. As a result, the `(n + 1)`-th rows
+  // of the two layouts do not correspond to the same free variable.
+  //
+  // To address this issue, we pad the free variables in `srcLayout` and
+  // `dstLayout` to ensure they have the same number of registers. This
+  // guarantees that the resulting matrices have the same number of rows,
+  // ensuring consistency in the composition process.
+  auto numRegs = std::max(numSrcRegs, numDstRegs);
+  auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
+  auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
+  // comp describes the layout function to create dst from src.
+  LinearLayout comp =
+      dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
+  // We try to quotient by the largest subspace first
+  auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
+  for (auto dim : dims) {
+    auto quotient = comp.quotient(StringAttr::get(ctx, dim));
+    if (!quotient.has_value()) {
+      break;
     }
+    comp = *quotient;
   }
-  return false;
+  return comp;
+}
+
+bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
+  auto layout = minimalCvtLayout(srcTy, dstTy);
+  MLIRContext *ctx = srcTy.getContext();
+  if (!layout.has_value()) {
+    return false;
+  }
+  auto kRegister = StringAttr::get(ctx, "register");
+  auto outDims = llvm::to_vector(layout->getOutDimNames());
+  return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister});
 }

 bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
+  auto layout = minimalCvtLayout(srcTy, dstTy);
   MLIRContext *ctx = srcTy.getContext();
-  std::optional<LinearLayout> srcLayout =
-      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-  std::optional<LinearLayout> dstLayout =
-      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
-  if (srcLayout.has_value() && dstLayout.has_value()) {
-    // comp describes the layout function for converting from src to dst.
-    LinearLayout comp = srcLayout->invertAndCompose(*dstLayout);
-    StringAttr kWarp = StringAttr::get(ctx, "warp");
-    StringAttr kBlock = StringAttr::get(ctx, "block");
-    if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kWarp),
-                                                  kWarp, kWarp) *
-                         LinearLayout::identity1D(comp.getInDimSize(kBlock),
-                                                  kBlock, kBlock))
-            .has_value()) {
-      return true;
-    }
+  if (!layout.has_value()) {
+    return false;
   }
-  return false;
+  auto kRegister = StringAttr::get(ctx, "register");
+  auto kLane = StringAttr::get(ctx, "lane");
+  return llvm::to_vector(layout->getOutDimNames()) ==
+         llvm::SmallVector<StringAttr, 2>{kRegister, kLane};
 }

 bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
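
The register padding is the core of the new `minimalCvtLayout` helper: both layouts are resized to the same number of `register` inputs before `invertAndCompose`, so the rows that the composition adds for free variables line up. A minimal sketch of that step in isolation, using only calls that appear in this patch (`getInDimSize`, `resize`, `invertAndCompose`); the helper name `composeWithPaddedRegs` is hypothetical:

// Sketch only: the padding step from minimalCvtLayout, isolated so the
// intent is easier to see. composeWithPaddedRegs is a hypothetical name.
static LinearLayout composeWithPaddedRegs(LinearLayout srcLayout,
                                          LinearLayout dstLayout,
                                          MLIRContext *ctx) {
  StringAttr kRegister = StringAttr::get(ctx, "register");
  // Pad both layouts to the same number of registers so the free variables
  // introduced by invertAndCompose refer to matching rows.
  auto numRegs = std::max(srcLayout.getInDimSize(kRegister),
                          dstLayout.getInDimSize(kRegister));
  return dstLayout.resize(kRegister, numRegs)
      .invertAndCompose(srcLayout.resize(kRegister, numRegs));
}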
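Downstream, the predicates read off how much of the minimal layout survived the quotients: an empty result or `{register}` means a pure in-thread register permutation, `{register, lane}` means data only crosses lanes within a warp, and anything still involving `warp` or `block` has to go through (distributed) shared memory. A sketch of how a caller might dispatch on them; only the three `cvt*` predicates come from this file, the `lowerVia*` helpers are hypothetical placeholders:

// Sketch only: pick the cheapest conversion path using the predicates above.
// lowerViaRegisters / lowerViaWarpShuffle / lowerViaSharedMemory are
// hypothetical helpers, not part of this patch.
void lowerConvertLayout(RankedTensorType srcTy, RankedTensorType dstTy) {
  if (cvtReordersRegisters(srcTy, dstTy)) {
    // Minimal layout is trivial or touches only `register`: each thread
    // already owns the right elements and just permutes them.
    lowerViaRegisters(srcTy, dstTy);
  } else if (cvtNeedsWarpShuffle(srcTy, dstTy)) {
    // Minimal layout maps onto {register, lane}: values move between lanes
    // of the same warp, so warp shuffles are enough.
    lowerViaWarpShuffle(srcTy, dstTy);
  } else {
    // The conversion also mixes warps and/or blocks; fall back to
    // (distributed) shared memory.
    lowerViaSharedMemory(srcTy, dstTy);
  }
}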