@@ -473,10 +473,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
473473 int nIndex = 1 + hasBatchDim;
474474 (void )mIndex , (void )nIndex;
475475
476- assert (((shape[mIndex ] == 1 || shape[mIndex ] >= getMDim ()) &&
477- (shape[nIndex] == 1 || shape[nIndex] >= getNDim ())) &&
478- " Unsupported tensor shape for given mfma layout" );
479-
480476 assert (((getMDim () == 32 && getNDim () == 32 ) ||
481477 (getMDim () == 16 && getNDim () == 16 )) &&
482478 " Unsupported mfma type" );
@@ -580,55 +576,76 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
580576 // 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is
581577 // held by exactly one thread, maintaining the same number of global loads
582578 // as in a blocked layout.
579+ //
580+ // Another use of the linear layout is to support rare corner cases,
581+ // for example when one instruction tile is larger than the tensor.
583582 auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent ());
584583
585- if (dotMfmaLayout.getOpIdx () == 0 ) {
586- return std::nullopt ;
587- }
588584 auto rank = shape.size ();
589585 bool hasBatchDim = rank == 3 ;
590586 int mIndex = 0 + hasBatchDim;
591587
592- auto kWidth = dotMfmaLayout.getKWidth ();
588+ int32_t kWidth = dotMfmaLayout.getKWidth ();
589+ auto kDim = dotMfmaLayout.getOpIdx () == 0 ? rank - 1 : rank - 2 ;
590+ int32_t kSize = shape[kDim ];
593591 auto warpsPerCTA = mfmaLayout.getWarpsPerCTA ();
594592
595- if (kWidth != 8 || warpsPerCTA[mIndex ] != 1 ) {
596- return std::nullopt ;
597- }
598-
599593 MLIRContext *ctx = dotMfmaLayout.getContext ();
600594 SmallVector<StringAttr> outDimNames = standardOutDimNames (ctx, rank);
601595
602596 StringAttr kRegister = S (" register" );
603597 StringAttr kLane = S (" lane" );
598+ StringAttr kWarp = S (" warp" );
604599
600+ // register order
601+ // operand A: [1, 0] / [2, 1, 0]
602+ // operand B: [0, 1] / [1, 2, 0]
603+ // for both cases it is [k, nonk]/[k, nonk, batch]
605604 SmallVector<unsigned > order = triton::gpu::getOrder (dotMfmaLayout);
606- auto tileLayout = LinearLayout::empty ();
605+ // warp order
606+ // common for both operand A and B: [0, 1] / [0, 1, 2]
607+ // in both cases it is [M dim, N dim]/[batch, M dim, N dim]
608+ SmallVector<unsigned > warpOrder = triton::gpu::getWarpOrder (dotMfmaLayout);
609+
610+ // Lane holds kWidth consecutive elements along the k dimension, so
611+ // base register vectors for one tile are initialized in the following way:
612+ // {1, 0}, {2, 0} ... {kWidth/2, 0}
613+ std::vector<std::vector<int32_t >> registerBase;
614+ for (int32_t elem = 1 ; elem < kWidth ; elem *= 2 )
615+ registerBase.emplace_back (std::vector<int32_t >{elem, 0 });
616+
617+ std::vector<std::vector<int32_t >> laneBase;
618+ int32_t kTileSize = -1 ;
607619
608620 if (mfmaLayout.getMDim () == 32 ) {
609- // Based on canonical MFMA linear layout, which handles 4 consecutive
610- // elements along the register dimension, kWidth=8 means we have 8
611- // consecutive elements, so we have an additional {4, 0} base vector here.
612- // For lane dim, since the MFMA thread arrangement is {K, N} = {2, 32}, this
613- // means that mapping of first 5 base (up to thread 16) vectors will be an
614- // identity along N dim. Thread 32 will be mapped to element 8 in K
615- // dimension, because kWidth == 8.
616- tileLayout = LinearLayout (
617- {{kRegister , {{1 , 0 }, {2 , 0 }, {4 , 0 }}},
618- {kLane , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {0 , 16 }, {8 , 0 }}}},
619- {outDimNames[order[0 ]], outDimNames[order[1 ]]});
621+ // Canonical MFMA linear layout handles 4 consecutive elements along
622+ // the register dimension. Dot operand handles variable kWidth consecutive
623+ // elements. For lane dim, since the MFMA thread arrangement is {K, N} = {2,
624+ // 32}, this means that mapping of first 5 base (up to thread 16) vectors
625+ // will be an identity along N dim. Thread 32 will be mapped to element
626+ // kWidth in K dimension.
627+ laneBase = {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {0 , 16 }, {kWidth , 0 }};
628+ kTileSize = kWidth * 2 ;
620629 } else {
621630 assert (mfmaLayout.getMDim () == 16 );
622631 // For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this
623632 // means that mapping of first 4 base (up to thread 16) vectors will be an
624- // identity along N dim. Thread 16 will be mapped to element 8 in K
625- // dimension, because kWidth == 8. Thread 32 is mapped to element 16 as that
626- // is 2*kWidth in K dim.
627- tileLayout = LinearLayout (
628- {{kRegister , {{1 , 0 }, {2 , 0 }, {4 , 0 }}},
629- {kLane , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {8 , 0 }, {16 , 0 }}}},
630- {outDimNames[order[0 ]], outDimNames[order[1 ]]});
633+ // identity along N dim. Thread 16 will be mapped to element kWidth in K
634+ // dimension. Thread 32 is mapped to element 2*kWidth in K dim.
635+ laneBase = {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {kWidth , 0 }, {kWidth * 2 , 0 }};
636+ kTileSize = kWidth * 4 ;
631637 }
638+ assert (kTileSize != -1 );
639+ // Add repeats of registers along K dimension to register base vectors
640+ for (int32_t elem = kTileSize ; elem < kSize ; elem *= 2 )
641+ registerBase.emplace_back (std::vector<int32_t >{elem, 0 });
642+
643+ // Base vectors above are defined in a fixed order [non-k-dim, k-dim].
644+ // To assign them to actual matrix dimensions, the `order` array is used.
645+ // For operand A: non-k-dim -> dim0, k-dim -> dim1
646+ // For operand B: non-k-dim -> dim1, k-dim -> dim0
647+ LinearLayout tileLayout ({{kRegister , registerBase}, {kLane , laneBase}},
648+ {outDimNames[order[0 ]], outDimNames[order[1 ]]});
632649
633650 if (hasBatchDim) {
634651 assert (order[2 ] == 0 );
@@ -639,8 +656,10 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
639656 }
640657
641658 LinearLayout warpLayout =
642- identityND (S (" warp" ), warpsPerCTA, order, outDimNames);
643- LinearLayout ctaLayout = tileLayout * warpLayout;
659+ identityND (kWarp , warpsPerCTA, warpOrder, outDimNames);
660+
661+ LinearLayout ctaLayout = tileLayout.transposeOuts (outDimNames) *
662+ warpLayout.transposeOuts (outDimNames);
644663
645664 return combineCtaCgaWithShape (ctaLayout, mfmaLayout.getCTALayout (), shape);
646665}
@@ -1001,6 +1020,8 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
10011020 mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding ());
10021021 if (!mmaLayout || !mmaLayout.isHopper ())
10031022 return false ;
1023+ if (isa<PointerType>(tensorTy.getElementType ()))
1024+ return false ;
10041025 if (tensorTy.getElementType ().getIntOrFloatBitWidth () != 16 )
10051026 return false ;
10061027 if (order[0 ] != 1 )
0 commit comments