@@ -396,10 +396,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   int nIndex = 1 + hasBatchDim;
   (void)mIndex, (void)nIndex;
 
-  assert(((getMDim() == 32 && getNDim() == 32) ||
-          (getMDim() == 16 && getNDim() == 16)) &&
-         "Unsupported mfma type");
-
   MLIRContext *ctx = getContext();
   SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
 
@@ -411,85 +407,56 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   // register and lane layout for mfma instructions.
 
   // We use the order from fastest varying to slowest varying. So each base
-  // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices.
+  // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices,
+  // which will be [1, 0] / [2, 1, 0].
   SmallVector<unsigned> order = getDefaultMmaOrder(*this);
-  auto tileLayout = LinearLayout::empty();
-
-  if (getMDim() == 32) {
-    // For mfma with 32x32 output, each of the 64 threads holds 16 elements.
-    //
-    // For the register (i.e., element) dimension, these 16 elements are along
-    // the matrix C's M dimension, with 4 consecutive elements spanning 4 rows
-    // and then the next 4 rows being a gap.
-    //
-    // For the lane (i.e., thread) dimension, these threads are along the
-    // matrix C's N dimension, with 32 consecutive threads covering a whole
-    // row and the next 32 threads start after a gap spanning 4 rows.
-    tileLayout = LinearLayout(
-        {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
-         {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
-        {outDimNames[order[0]], outDimNames[order[1]]});
-    // For mfma.transposed layout, the element ownership among threads are
-    // "transposed" within each warp.
-    if (getIsTransposed())
-      tileLayout = LinearLayout(
-          {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
-           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
-          {outDimNames[order[0]], outDimNames[order[1]]});
+  auto dimM = outDimNames[order[1]];
+  auto dimN = outDimNames[order[0]];
+
+  unsigned mDim = getMDim();
+  unsigned nDim = getNDim();
+  auto elementType = getElementType();
+  int height = (elementType && elementType->isF64()) ? 1 : 4;
+  constexpr int warpSize = 64;
+
+  bool isTransposed = getIsTransposed();
+  // Special case for 64x4 mfma: we always transpose the output to turn the
+  // 64x4 mfma into an equivalent 4x64 mfma and swap operands A and B, so that
+  // we can use the mfma broadcast.
+  if (mDim == 64 && nDim == 4)
+    assert(isTransposed && "64x4 mfma must be transposed");
+
+  int tiles = (mDim * nDim) / (warpSize * height);
+
+  LinearLayout tileLayout = LinearLayout::empty();
+  if (!isTransposed) {
+    // Each lane holds 'height' elements along the M dimension.
+    LinearLayout regs = LinearLayout::identity1D(height, kRegister, dimM);
+    // First, distribute the lanes along the N dimension.
+    // Then, distribute the lanes along the M dimension. If the number of lanes
+    // along M exceeds mDim, elements are duplicated across lanes; this can
+    // happen for the 4x4 output.
+    LinearLayout lanes = LinearLayout::identity1D(nDim, kLane, dimN) *
+                         LinearLayout::identity1D(warpSize / nDim, kLane, dimM);
+    tileLayout = (regs * lanes);
+
+    // Repeat the above distribution along the M dimension to fill the tile.
+    if (tiles > 0)
+      tileLayout *= LinearLayout::identity1D(tiles, kRegister, dimM);
   } else {
-    assert(getMDim() == 16);
-    auto elementType = getElementType();
-    if (!(elementType && elementType->isF64())) {
-      // For mfma with 16x16 output (<= 32 bits), each of the 64 threads holds 4
-      // elements.
-      //
-      // For the register (i.e., element) dimension, these 4 elements are along
-      // the matrix C's M dimension, with 4 consecutive elements spanning 4
-      // rows.
-      //
-      // For the lane (i.e., thread) dimension, these threads are along the
-      // matrix C's N dimension, with 16 consecutive threads covering a whole
-      // row and the next 16 threads start after a gap spanning 4 rows.
-      tileLayout = LinearLayout(
-          {{kRegister, {{0, 1}, {0, 2}}},
-           {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
-          {outDimNames[order[0]], outDimNames[order[1]]});
-      // For mfma.transposed layout, the element ownership among threads are
-      // "transposed" within each warp.
-      if (getIsTransposed())
-        tileLayout = LinearLayout(
-            {{kRegister, {{1, 0}, {2, 0}}},
-             {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
-            {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the transposed output, we use the same method for the layout but
+    // swap the order of the M and N dimensions.
+    LinearLayout regs = LinearLayout::identity1D(height, kRegister, dimN);
+    LinearLayout lanes = LinearLayout::identity1D(mDim, kLane, dimM) *
+                         LinearLayout::identity1D(warpSize / mDim, kLane, dimN);
+    tileLayout = (regs * lanes);
 
-    } else {
-      // For 64 bit mfma with 16x16 output, each of the 64 threads holds 4
-      // elements across 8 VGPRs. each 64 bit element is split across pairs of 2
-      // VGPRs each. The first VGPR holds the first 32 bits and second holding
-      // the last 32 bits.
-      //
-      // For the register (i.e., element) dimension, these 4 elements are along
-      // the matrix C's M dimension, with 4 consecutive elements spanning 4
-      // rows.
-      //
-      // For the lane (i.e., thread) dimension, these threads are along the
-      // matrix C's N dimension, with each group of 16 consecutive threads
-      // covering a whole adjacent row. Unlike the <=32 bit cases, there's no
-      // row gaps between the groups.
-      tileLayout = LinearLayout(
-          {{kRegister, {{0, 4}, {0, 8}}},
-           {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 1}, {0, 2}}}},
-          {outDimNames[order[0]], outDimNames[order[1]]});
-      // For mfma.transposed layout, the element ownership among threads are
-      // "transposed" within each warp.
-      if (getIsTransposed())
-        tileLayout = LinearLayout(
-            {{kRegister, {{4, 0}, {8, 0}}},
-             {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}},
-            {outDimNames[order[0]], outDimNames[order[1]]});
-    }
+    if (tiles > 0)
+      tileLayout *= LinearLayout::identity1D(tiles, kRegister, dimN);
   }
 
+  tileLayout = tileLayout.transposeOuts({dimN, dimM});
+
   // Instead of defining the layout on a CTA tile and using the
   // combineCtaCgaWithShape function to extend it to the whole tensor, we take a
   // different approach. Suppose tilesPerWarp is 2x2—meaning a warp computes a
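The hunk above replaces hard-coded base-vector tables with a composition of `LinearLayout::identity1D` and `operator*`. The standalone sketch below models only the base-vector bookkeeping of that composition; it does not use the real `LinearLayout` class, and the way `operator*` scales the second operand's bases on shared output dimensions is an assumption inferred from the code in this patch. With that model, the non-transposed 32x32 case reproduces the deleted register bases `{0, 1}, {0, 2}, {0, 8}, {0, 16}` and lane bases `{1, 0} ... {16, 0}, {0, 4}`.

```cpp
// Hypothetical toy model of identity1D and operator* on base vectors only.
// Names (Layout, identity1D) mirror the patch but this is not the Triton API.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct Layout {
  // inDim ("register"/"lane") -> list of base vectors, each outDim -> value.
  std::map<std::string, std::vector<std::map<std::string, int>>> bases;
  std::map<std::string, int> outSizes; // outDim -> size covered so far
};

Layout identity1D(int size, const std::string &inDim, const std::string &outDim) {
  Layout l;
  for (int b = 1; b < size; b *= 2) // one base per power of two
    l.bases[inDim].push_back({{outDim, b}});
  l.outSizes[outDim] = size;
  return l;
}

// A * B: B's bases are scaled by A's size (moved to higher bits) on every
// output dimension the two layouts share, then appended after A's bases.
Layout operator*(Layout a, const Layout &b) {
  for (const auto &[inDim, vecs] : b.bases)
    for (auto vec : vecs) {
      for (auto &[outDim, val] : vec)
        if (a.outSizes.count(outDim))
          val *= a.outSizes[outDim];
      a.bases[inDim].push_back(vec);
    }
  for (const auto &[outDim, size] : b.outSizes) {
    if (a.outSizes.count(outDim))
      a.outSizes[outDim] *= size;
    else
      a.outSizes[outDim] = size;
  }
  return a;
}

int main() {
  // Non-transposed 32x32 mfma: height = 4, warpSize = 64, tiles = 4.
  Layout regs = identity1D(4, "register", "M");
  Layout lanes = identity1D(32, "lane", "N") * identity1D(2, "lane", "M");
  Layout tile = regs * lanes;
  tile = tile * identity1D(4, "register", "M");
  for (const auto &[inDim, vecs] : tile.bases) {
    std::printf("%s:", inDim.c_str());
    for (const auto &v : vecs)
      std::printf(" (N=%d, M=%d)", v.count("N") ? v.at("N") : 0,
                  v.count("M") ? v.at("M") : 0);
    std::printf("\n");
  }
  // Expected: lane -> N = 1, 2, 4, 8, 16, then M = 4;
  //           register -> M = 1, 2, 8, 16,
  // matching the hard-coded 32x32 bases removed above.
}
```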
@@ -511,8 +478,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   const unsigned tilesPerWarpN = tilesPerWarp[nIndex];
   const unsigned warpsPerCTAM = warpsPerCTA[mIndex];
   const unsigned warpsPerCTAN = warpsPerCTA[nIndex];
-  const auto &dimN = outDimNames[order[0]];
-  const auto &dimM = outDimNames[order[1]];
 
   // First, extend the layout along the N dimension:
   // - registers are distributed across tilesPerWarpN
@@ -766,11 +731,19 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   int mIndex = 0 + hasBatchDim;
 
   int32_t kWidth = dotMfmaLayout.getKWidth();
-  auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
-  int32_t kSize = shape[kDim];
+  auto kDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+
   auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
   auto tilesPerWarp = mfmaLayout.getTilesPerWarp();
-  auto tilePerWarpNonK = tilesPerWarp[kDim];
+  auto tilePerWarpNonK = tilesPerWarp[kDimIndex];
+
+  auto mDim = mfmaLayout.getMDim();
+  auto nDim = mfmaLayout.getNDim();
+  auto opIdx = dotMfmaLayout.getOpIdx();
+  auto nonKDim = opIdx == 0 ? mDim : nDim;
+  constexpr int warpSize = 64;
+
+  int32_t kSize = shape[kDimIndex];
 
   MLIRContext *ctx = dotMfmaLayout.getContext();
   SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
@@ -785,57 +758,43 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   // for both cases it is [k, nonk]/[k, nonk, batch]
   auto order =
       getOrderForDotOperand(dotMfmaLayout.getOpIdx(), rank, /*kContig*/ true);
+  auto dimK = outDimNames[order[0]];
+  auto dimNonK = outDimNames[order[1]];
 
   // warp order
   // common for both operand A and B: [0, 1] / [0, 1, 2]
   // in both cases it is [M dim, N dim]/[batch, M dim, N dim]
   auto warpOrder = getDefaultMmaOrder(mfmaLayout);
 
-  // Lane holds kWidth consecutive elements along k dimension, so
-  // base register vectors for one tile are initialized in following way:
-  // {1, 0}, {2, 0} ... {kWidth/2, 0}
-  std::vector<std::vector<int32_t>> registerBase;
-  for (int32_t elem = 1; elem < kWidth; elem *= 2)
-    registerBase.emplace_back(std::vector<int32_t>{elem, 0});
-
-  std::vector<std::vector<int32_t>> laneBase;
-  int32_t kTileSize = -1;
-
-  if (mfmaLayout.getMDim() == 32) {
-    // Canonical MFMA linear layout handles 4 consecutive elements along
-    // the register dimension. Dot operand handles variable kWidth consecutive
-    // elements. For lane dim, since the MFMA thread arrangement is {K, N} = {2,
-    // 32}, this means that mapping of first 5 base (up to thread 16) vectors
-    // will be an identity along N dim. Thread 32 will be mapped to element
-    // kWidth in K dimension.
-    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}};
-    kTileSize = kWidth * 2;
-  } else {
-    assert(mfmaLayout.getMDim() == 16);
-    // For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this
-    // means that mapping of first 4 base (up to thread 16) vectors will be an
-    // identity along N dim. Thread 16 will be mapped to element kWisth in K
-    // dimension. Thread 32 is mapped to element 2*kWidth in K dim.
-    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}};
-    kTileSize = kWidth * 4;
-  }
-  assert(kTileSize != -1);
-  // Add repeats of registers along K dimension to register base vectors
-  for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
-    registerBase.emplace_back(std::vector<int32_t>{elem, 0});
+  // Each lane holds kWidth elements along the K dimension.
+  LinearLayout regs = LinearLayout::identity1D(kWidth, kRegister, dimK);
+  // First, distribute the lanes along the non-K dimension (nonKDim lanes),
+  // then distribute the remaining lanes along the K dimension.
+  LinearLayout lanes =
+      LinearLayout::identity1D(nonKDim, kLane, dimNonK) *
+      LinearLayout::identity1D(warpSize / nonKDim, kLane, dimK);
+  LinearLayout tileLayout = regs * lanes;
 
-  // Base vectors above are defined in a fixed order [k-dim, non-k-dim].
-  // To assign them to actual matrix dimensions we assoicate with register
-  // `order` which is also also [k, nonk].
+  int kTileSize = warpSize / nonKDim * kWidth;
+  // Special case for 4x64 and 64x4 mfma: for the 64x64 operand, we need to
+  // repeat the layout 16 times along the K dimension.
+  if ((mDim == 64 && nDim == 4 && opIdx == 0) ||
+      (mDim == 4 && nDim == 64 && opIdx == 1)) {
+    tileLayout *= LinearLayout::identity1D(16, kRegister, dimK);
+    kTileSize *= 16;
+  }
 
-  // Add repeats of registers along non-K dimension to register base vectors
-  for (int32_t elem = mfmaLayout.getMDim();
-       elem < tilePerWarpNonK * mfmaLayout.getMDim(); elem *= 2)
-    registerBase.emplace_back(std::vector<int32_t>{0, elem});
+  // If the shape along K is larger than the tile size, repeat the tile
+  // along the K dimension.
+  if (kSize > kTileSize) {
+    tileLayout *= LinearLayout::identity1D(kSize / kTileSize, kRegister, dimK);
+  }
 
-  LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
-                          {outDimNames[order[0]], outDimNames[order[1]]});
+  // Following the tilesPerWarp property, repeat the tile layout along the
+  // non-K dimension.
+  tileLayout *= LinearLayout::identity1D(tilePerWarpNonK, kRegister, dimNonK);
 
+  tileLayout = tileLayout.transposeOuts({dimK, dimNonK});
   if (hasBatchDim) {
     assert(order[2] == 0);
     // Extend the base vector with one value to accommodate for the batch
@@ -845,11 +804,14 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   }
 
   LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
-
-  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
-                           warpLayout.transposeOuts(outDimNames);
-
-  return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
+  LinearLayout ctaLayout = tileLayout * warpLayout;
+
+  // Note that the current output order is [k, nonk]/[k, nonk, batch]. If the
+  // layout's out-size is smaller than the shape, we follow this order to
+  // extend each dimension to match the shape. After that, we can transpose
+  // to match the standard output order.
+  return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape)
+      .transposeOuts(outDimNames);
 }
 
 LinearLayout
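For the dot-operand path in this last hunk, the generalized `kTileSize = warpSize / nonKDim * kWidth` reduces to the per-case constants that were previously hard-coded: `kWidth * 2` for the 32x32 mfma and `kWidth * 4` for 16x16. A minimal standalone check of that arithmetic (illustrative only, not part of the patch; the sample kWidth values are arbitrary):

```cpp
#include <cassert>

int main() {
  constexpr int warpSize = 64;
  for (int kWidth : {4, 8, 16}) {
    // nonKDim == 32 (mDim == 32): the old code used kTileSize = kWidth * 2.
    assert(warpSize / 32 * kWidth == kWidth * 2);
    // nonKDim == 16 (mDim == 16): the old code used kTileSize = kWidth * 4.
    assert(warpSize / 16 * kWidth == kWidth * 4);
  }
  return 0;
}
```

In the same spirit, `identity1D(nonKDim, kLane, dimNonK) * identity1D(warpSize / nonKDim, kLane, dimK)` appears to reproduce the deleted laneBase tables (for nonKDim == 32: `{0, 1} ... {0, 16}, {kWidth, 0}`) once the kWidth register bases occupy the low K bits; this reading is inferred from the patch rather than stated in it.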