@@ -472,10 +472,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   int nIndex = 1 + hasBatchDim;
   (void)mIndex, (void)nIndex;
 
-  assert(((shape[mIndex] == 1 || shape[mIndex] >= getMDim()) &&
-          (shape[nIndex] == 1 || shape[nIndex] >= getNDim())) &&
-         "Unsupported tensor shape for given mfma layout");
-
   assert(((getMDim() == 32 && getNDim() == 32) ||
           (getMDim() == 16 && getNDim() == 16)) &&
          "Unsupported mfma type");
@@ -579,55 +575,76 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   // 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is
   //    held by exactly one thread, maintaining the same number of global loads
   //    as in a blocked layout.
+  //
+  // Linear layout is also used to support rare corner cases, e.g. when a
+  // single instruction tile is larger than the tensor.
   auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
 
-  if (dotMfmaLayout.getOpIdx() == 0) {
-    return std::nullopt;
-  }
   auto rank = shape.size();
   bool hasBatchDim = rank == 3;
   int mIndex = 0 + hasBatchDim;
 
-  auto kWidth = dotMfmaLayout.getKWidth();
+  int32_t kWidth = dotMfmaLayout.getKWidth();
+  auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  int32_t kSize = shape[kDim];
   auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
 
-  if (kWidth != 8 || warpsPerCTA[mIndex] != 1) {
-    return std::nullopt;
-  }
-
   MLIRContext *ctx = dotMfmaLayout.getContext();
   SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
 
   StringAttr kRegister = S("register");
   StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
 
+  // Register order:
+  //   operand A: [1, 0] / [2, 1, 0]
+  //   operand B: [0, 1] / [1, 2, 0]
+  // In both cases this is [k, nonk] / [k, nonk, batch].
   SmallVector<unsigned> order = triton::gpu::getOrder(dotMfmaLayout);
-  auto tileLayout = LinearLayout::empty();
+  // Warp order:
+  //   common to both operands A and B: [0, 1] / [0, 1, 2]
+  // In both cases this is [M dim, N dim] / [batch, M dim, N dim].
+  SmallVector<unsigned> warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout);
+
+  // A lane holds kWidth consecutive elements along the k dimension, so the
+  // register base vectors for one tile are initialized as:
+  // {1, 0}, {2, 0} ... {kWidth/2, 0}
+  std::vector<std::vector<int32_t>> registerBase;
+  for (int32_t elem = 1; elem < kWidth; elem *= 2)
+    registerBase.emplace_back(std::vector<int32_t>{elem, 0});
+
+  std::vector<std::vector<int32_t>> laneBase;
+  int32_t kTileSize = -1;
 
   if (mfmaLayout.getMDim() == 32) {
-    // Based on canonical MFMA linear layout, which handles 4 consecutive
-    // elements along the register dimension, kWidth=8 means we have 8
-    // consecutive elements, so we have an additional {4, 0} base vector here.
-    // For lane dim, since the MFMA thread arrangement is {K, N} = {2, 32}, this
-    // means that mapping of first 5 base (up to thread 16) vectors will be an
-    // identity along N dim. Thread 32 will be mapped to element 8 in K
-    // dimension, because kWidth == 8.
-    tileLayout = LinearLayout(
-        {{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
-         {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}},
-        {outDimNames[order[0]], outDimNames[order[1]]});
+    // The canonical MFMA linear layout handles 4 consecutive elements along
+    // the register dimension; the dot operand handles a variable number
+    // (kWidth) of consecutive elements. For the lane dim, since the MFMA
+    // thread arrangement is {K, N} = {2, 32}, the mapping of the first 5 base
+    // vectors (up to thread 16) is an identity along the N dim. Thread 32 is
+    // mapped to element kWidth in the K dimension.
+    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}};
+    kTileSize = kWidth * 2;
   } else {
     assert(mfmaLayout.getMDim() == 16);
     // For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this
     // means that mapping of first 4 base (up to thread 16) vectors will be an
-    // identity along N dim. Thread 16 will be mapped to element 8 in K
-    // dimension, because kWidth == 8. Thread 32 is mapped to element 16 as that
-    // is 2*kWidth in K dim.
-    tileLayout = LinearLayout(
-        {{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
-         {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}},
-        {outDimNames[order[0]], outDimNames[order[1]]});
+    // identity along N dim. Thread 16 will be mapped to element kWidth in K
+    // dimension. Thread 32 is mapped to element 2*kWidth in K dim.
+    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}};
+    kTileSize = kWidth * 4;
   }
+  assert(kTileSize != -1);
+  // Add repeats of the tile along the K dimension to the register base vectors.
+  for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
+    registerBase.emplace_back(std::vector<int32_t>{elem, 0});
+
+  // Base vectors above are defined in a fixed order [k-dim, non-k-dim].
+  // The `order` array assigns them to the actual matrix dimensions:
+  //   operand A: non-k-dim -> dim0, k-dim -> dim1
+  //   operand B: non-k-dim -> dim1, k-dim -> dim0
+  LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
+                          {outDimNames[order[0]], outDimNames[order[1]]});
 
   if (hasBatchDim) {
     assert(order[2] == 0);
@@ -638,8 +655,10 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   }
 
   LinearLayout warpLayout =
-      identityND(S("warp"), warpsPerCTA, order, outDimNames);
-  LinearLayout ctaLayout = tileLayout * warpLayout;
+      identityND(kWarp, warpsPerCTA, warpOrder, outDimNames);
+
+  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
+                           warpLayout.transposeOuts(outDimNames);
 
   return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
 }
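
For intuition, here is a minimal, self-contained sketch (not Triton code) of the base-vector construction above. It mimics how a LinearLayout applies its bases: the input index selects the base vectors at its set bits and XORs their coordinates together, since Triton's linear layouts are linear over GF(2). The `apply` helper and the concrete kWidth/kSize values are illustrative assumptions, not Triton API.

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for LinearLayout application: the output coordinate is the XOR of
// the base vectors selected by the set bits of the input index.
static std::vector<int32_t>
apply(const std::vector<std::vector<int32_t>> &bases, int32_t index) {
  std::vector<int32_t> out(2, 0);
  for (size_t bit = 0; bit < bases.size(); ++bit)
    if (index & (1 << bit))
      for (size_t d = 0; d < out.size(); ++d)
        out[d] ^= bases[bit][d];
  return out;
}

int main() {
  // Mirror the patch for kWidth = 8 on a 32x32 MFMA with kSize = 32.
  // Vector components are [k, non-k], as in the patch.
  int32_t kWidth = 8, kSize = 32;

  std::vector<std::vector<int32_t>> registerBase;
  for (int32_t elem = 1; elem < kWidth; elem *= 2)
    registerBase.push_back({elem, 0}); // {1,0}, {2,0}, {4,0}

  std::vector<std::vector<int32_t>> laneBase = {
      {0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}};

  int32_t kTileSize = kWidth * 2; // 16: one tile covers 16 k-elements
  for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
    registerBase.push_back({elem, 0}); // k-repeat: {16, 0}

  // Thread 32 lands at k = kWidth, as the laneBase comment states:
  auto c = apply(laneBase, 32);
  std::cout << "lane 32     -> k=" << c[0] << " nonk=" << c[1] << "\n"; // 8, 0

  // Register 8 (bit 3) selects the k-repeat base {16, 0}:
  c = apply(registerBase, 8);
  std::cout << "register 8  -> k=" << c[0] << " nonk=" << c[1] << "\n"; // 16, 0

  // 16x16 case: thread 16 -> kWidth, thread 32 -> 2*kWidth.
  std::vector<std::vector<int32_t>> laneBase16 = {
      {0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}};
  c = apply(laneBase16, 48); // bits 4 and 5: k = 8 ^ 16 = 24
  std::cout << "lane 48 (16x16) -> k=" << c[0] << " nonk=" << c[1] << "\n";
}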
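A note on the final product: `tileLayout` is built with its output dimensions in `order` (k-dim first), while `warpLayout` follows `warpOrder`, so both are brought to the common `outDimNames` order via `transposeOuts` before they can be multiplied. The removed code could multiply directly because it constructed both layouts with the same `order`.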