@@ -34,9 +34,9 @@ int getParentAxis(Attribute layout, int axis) {
3434 return axis;
3535}
3636
37- SmallVector<unsigned > getParentOrder (Attribute layout) {
37+ SmallVector<unsigned > getParentThreadOrder (Attribute layout) {
3838 if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
39- return getParentOrder (sliceEncoding.getParent ());
39+ return getParentThreadOrder (sliceEncoding.getParent ());
4040 }
4141 return getThreadOrder (layout);
4242}
@@ -46,12 +46,12 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
4646// TODO(jlebar): Move this class into namespace triton.
4747bool ReduceOpHelper::isReductionOnLayoutFastAxis () {
4848 return getParentAxis (getSrcLayout (), axis) ==
49- getParentOrder (getSrcLayout ())[0 ];
49+ getParentThreadOrder (getSrcLayout ())[0 ];
5050}
5151
52- SmallVector<unsigned > ReduceOpHelper::getOrderWithAxisAtBeginning () {
52+ SmallVector<unsigned > ReduceOpHelper::getThreadOrderWithAxisAtBeginning () {
5353 auto srcLayout = getSrcLayout ();
54- auto order = getOrder (srcLayout);
54+ auto order = getThreadOrder (srcLayout);
5555 auto it = std::find (order.begin (), order.end (), axis);
5656 // delete the axis from order
5757 order.erase (it);
@@ -543,7 +543,7 @@ bool supportMMA(Value value, int version) {
543543 (elemTy.isInteger (8 ) && version >= 2 );
544544}
545545
546- bool isBlockedToDotShortcut (RankedTensorType & srcTy, RankedTensorType & dstTy) {
546+ bool isBlockedToDotShortcut (RankedTensorType srcTy, RankedTensorType dstTy) {
547547 auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding ());
548548 auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding ());
549549 if (blockedLayout == nullptr || dotOperandLayout == nullptr )
@@ -646,8 +646,46 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
646646 toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
647647 if (!(srcLayout.has_value () && dstLayout.has_value ()))
648648 return std::nullopt ;
649+ StringAttr kRegister = StringAttr::get (ctx, " register" );
650+ StringAttr kLane = StringAttr::get (ctx, " lane" );
651+ StringAttr kWarp = StringAttr::get (ctx, " warp" );
652+ StringAttr kBlock = StringAttr::get (ctx, " block" );
653+ auto numSrcRegs = srcLayout->getInDimSize (kRegister );
654+ auto numDstRegs = dstLayout->getInDimSize (kRegister );
655+ // The `invertAndCompose` function will generate a layout that is injective
656+ // by assigning new output dimensions to free variables. For instance,
657+ // consider a scenario where `srcLayout` has a free variable in the lane
658+ // dimension, while `dstLayout` has two free variables in the lane
659+ // dimension and also a larger number of registers.
660+ // The injective form of `srcLayout` will add only a single additional row
661+ // to the transformation matrix, whereas the injective form of `dstLayout`
662+ // will add two additional rows. This discrepancy causes misleading results
663+ // because the matrices end up with a different number of rows.
664+ //
665+ // Take `dstLayout ⋅ srcLayout^-1` as an example:
666+ //
667+ // - `injective(dstLayout)`: [n, m] → [n + 2, m]
668+ // - `injective(srcLayout)`: [n, m] → [n + 1, m]
669+ // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
670+ // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
671+ // 1] → [n + 2, n + 1]
672+ //
673+ // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
674+ // variable in registers, and the `(n + 2)`-th row represents the free
675+ // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
676+ // represents the free variable in lanes. As a result, the `(n + 1)`-th row
677+ // in two layouts do not correspond to the same free variable.
678+ //
679+ // To address this issue, we pad the free variables in `srcLayout` and
680+ // `dstLayout` to ensure they have the same number of registers. This
681+ // guarantees that the resulting matrices have the same number of rows,
682+ // ensuring consistency in the composition process.
683+ auto numRegs = std::max (numSrcRegs, numDstRegs);
684+ auto srcLayoutWithFreeRegs = srcLayout->resize (kRegister , numRegs);
685+ auto dstLayoutWithFreeRegs = dstLayout->resize (kRegister , numRegs);
649686 // comp describes the layout function to create dst from src.
650- LinearLayout comp = dstLayout->invertAndCompose (*srcLayout);
687+ LinearLayout comp =
688+ dstLayoutWithFreeRegs.invertAndCompose (srcLayoutWithFreeRegs);
651689 // We try to quotient by the largest subspace first
652690 auto dims = SmallVector<StringRef>{" block" , " warp" , " lane" , " register" };
653691 for (auto dim : dims) {
0 commit comments