@@ -396,10 +396,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
396
396
int nIndex = 1 + hasBatchDim;
397
397
(void )mIndex , (void )nIndex;
398
398
399
- assert (((getMDim () == 32 && getNDim () == 32 ) ||
400
- (getMDim () == 16 && getNDim () == 16 )) &&
401
- " Unsupported mfma type" );
402
-
403
399
MLIRContext *ctx = getContext ();
404
400
SmallVector<StringAttr> outDimNames = standardOutDimNames (ctx, rank);
405
401
@@ -411,85 +407,56 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
411
407
// register and lane layout for mfma instructions.
412
408
413
409
// We use the order from fastest varying to slowest varying. So each base
414
- // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices.
410
+ // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices,
411
+ // which will be [1, 0] / [2, 1, 0].
415
412
SmallVector<unsigned > order = getDefaultMmaOrder (*this );
416
- auto tileLayout = LinearLayout::empty ();
417
-
418
- if (getMDim () == 32 ) {
419
- // For mfma with 32x32 output, each of the 64 threads holds 16 elements.
420
- //
421
- // For the register (i.e., element) dimension, these 16 elements are along
422
- // the matrix C's M dimension, with 4 consecutive elements spanning 4 rows
423
- // and then the next 4 rows being a gap.
424
- //
425
- // For the lane (i.e., thread) dimension, these threads are along the
426
- // matrix C's N dimension, with 32 consecutive threads covering a whole
427
- // row and the next 32 threads start after a gap spanning 4 rows.
428
- tileLayout = LinearLayout (
429
- {{kRegister , {{0 , 1 }, {0 , 2 }, {0 , 8 }, /* gap*/ {0 , 16 }}},
430
- {kLane , {{1 , 0 }, {2 , 0 }, {4 , 0 }, {8 , 0 }, {16 , 0 }, /* gap*/ {0 , 4 }}}},
431
- {outDimNames[order[0 ]], outDimNames[order[1 ]]});
432
- // For mfma.transposed layout, the element ownership among threads are
433
- // "transposed" within each warp.
434
- if (getIsTransposed ())
435
- tileLayout = LinearLayout (
436
- {{kRegister , {{1 , 0 }, {2 , 0 }, {8 , 0 }, /* gap*/ {16 , 0 }}},
437
- {kLane , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {0 , 16 }, /* gap*/ {4 , 0 }}}},
438
- {outDimNames[order[0 ]], outDimNames[order[1 ]]});
413
+ auto dimM = outDimNames[order[1 ]];
414
+ auto dimN = outDimNames[order[0 ]];
415
+
416
+ unsigned mDim = getMDim ();
417
+ unsigned nDim = getNDim ();
418
+ auto elementType = getElementType ();
419
+ int height = (elementType && elementType->isF64 ()) ? 1 : 4 ;
420
+ constexpr int warpSize = 64 ;
421
+
422
+ bool isTransposed = getIsTransposed ();
423
+ // Special case for 64x4 mfma: we always transpose the output to turn
424
+ // the 64x4 mfma into an equivalent 4x64 mfma and swap operand A and B, so
425
+ // that we can use the mfma broadcast.
426
+ if (mDim == 64 && nDim == 4 )
427
+ assert (isTransposed && " 64x4 mfma must be transposed" );
428
+
429
+ int tiles = (mDim * nDim) / (warpSize * height);
430
+
431
+ LinearLayout tileLayout = LinearLayout::empty ();
432
+ if (!isTransposed) {
433
+ // Each lane holds 'height' elements along the M dimension.
434
+ LinearLayout regs = LinearLayout::identity1D (height, kRegister , dimM);
435
+ // First, distribute the lanes along the N dimension.
436
+ // Then, distribute the lanes along the M dimension. If the #elements
437
+ // exceeds the mDim, duplicate elements across lanes - this can happen for
438
+ // 4x4 output.
439
+ LinearLayout lanes = LinearLayout::identity1D (nDim, kLane , dimN) *
440
+ LinearLayout::identity1D (warpSize / nDim, kLane , dimM);
441
+ tileLayout = (regs * lanes);
442
+
443
+ // Repeat the above distribution along the M dimension to fit the tile.
444
+ if (tiles > 0 )
445
+ tileLayout *= LinearLayout::identity1D (tiles, kRegister , dimM);
439
446
} else {
440
- assert (getMDim () == 16 );
441
- auto elementType = getElementType ();
442
- if (!(elementType && elementType->isF64 ())) {
443
- // For mfma with 16x16 output (<= 32 bits), each of the 64 threads holds 4
444
- // elements.
445
- //
446
- // For the register (i.e., element) dimension, these 4 elements are along
447
- // the matrix C's M dimension, with 4 consecutive elements spanning 4
448
- // rows.
449
- //
450
- // For the lane (i.e., thread) dimension, these threads are along the
451
- // matrix C's N dimension, with 16 consecutive threads covering a whole
452
- // row and the next 16 threads start after a gap spanning 4 rows.
453
- tileLayout = LinearLayout (
454
- {{kRegister , {{0 , 1 }, {0 , 2 }}},
455
- {kLane , {{1 , 0 }, {2 , 0 }, {4 , 0 }, {8 , 0 }, /* gap*/ {0 , 4 }, {0 , 8 }}}},
456
- {outDimNames[order[0 ]], outDimNames[order[1 ]]});
457
- // For mfma.transposed layout, the element ownership among threads are
458
- // "transposed" within each warp.
459
- if (getIsTransposed ())
460
- tileLayout = LinearLayout (
461
- {{kRegister , {{1 , 0 }, {2 , 0 }}},
462
- {kLane , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, /* gap*/ {4 , 0 }, {8 , 0 }}}},
463
- {outDimNames[order[0 ]], outDimNames[order[1 ]]});
447
+ // For the transposed output, we will use the same method for layout but
448
+ // swap the order of the M and N dimensions.
449
+ LinearLayout regs = LinearLayout::identity1D (height, kRegister , dimN);
450
+ LinearLayout lanes = LinearLayout::identity1D (mDim , kLane , dimM) *
451
+ LinearLayout::identity1D (warpSize / mDim , kLane , dimN);
452
+ tileLayout = (regs * lanes);
464
453
465
- } else {
466
- // For 64 bit mfma with 16x16 output, each of the 64 threads holds 4
467
- // elements across 8 VGPRs. each 64 bit element is split across pairs of 2
468
- // VGPRs each. The first VGPR holds the first 32 bits and second holding
469
- // the last 32 bits.
470
- //
471
- // For the register (i.e., element) dimension, these 4 elements are along
472
- // the matrix C's M dimension, with 4 consecutive elements spanning 4
473
- // rows.
474
- //
475
- // For the lane (i.e., thread) dimension, these threads are along the
476
- // matrix C's N dimension, with each group of 16 consecutive threads
477
- // covering a whole adjacent row. Unlike the <=32 bit cases, there's no
478
- // row gaps between the groups.
479
- tileLayout = LinearLayout (
480
- {{kRegister , {{0 , 4 }, {0 , 8 }}},
481
- {kLane , {{1 , 0 }, {2 , 0 }, {4 , 0 }, {8 , 0 }, {0 , 1 }, {0 , 2 }}}},
482
- {outDimNames[order[0 ]], outDimNames[order[1 ]]});
483
- // For mfma.transposed layout, the element ownership among threads are
484
- // "transposed" within each warp.
485
- if (getIsTransposed ())
486
- tileLayout = LinearLayout (
487
- {{kRegister , {{4 , 0 }, {8 , 0 }}},
488
- {kLane , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {1 , 0 }, {2 , 0 }}}},
489
- {outDimNames[order[0 ]], outDimNames[order[1 ]]});
490
- }
454
+ if (tiles > 0 )
455
+ tileLayout *= LinearLayout::identity1D (tiles, kRegister , dimN);
491
456
}
492
457
458
+ tileLayout = tileLayout.transposeOuts ({dimN, dimM});
459
+
493
460
// Instead of defining the layout on a CTA tile and using the
494
461
// combineCtaCgaWithShape function to extend it to the whole tensor, we take a
495
462
// different approach. Suppose tilesPerWarp is 2x2—meaning a warp computes a
@@ -511,8 +478,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
511
478
const unsigned tilesPerWarpN = tilesPerWarp[nIndex];
512
479
const unsigned warpsPerCTAM = warpsPerCTA[mIndex ];
513
480
const unsigned warpsPerCTAN = warpsPerCTA[nIndex];
514
- const auto &dimN = outDimNames[order[0 ]];
515
- const auto &dimM = outDimNames[order[1 ]];
516
481
517
482
// First, extend the layout along the N dimension:
518
483
// - registers are distributed across tilesPerWarpN
@@ -766,11 +731,19 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
766
731
int mIndex = 0 + hasBatchDim;
767
732
768
733
int32_t kWidth = dotMfmaLayout.getKWidth ();
769
- auto kDim = dotMfmaLayout.getOpIdx () == 0 ? rank - 1 : rank - 2 ;
770
- int32_t kSize = shape[ kDim ];
734
+ auto kDimIndex = dotMfmaLayout.getOpIdx () == 0 ? rank - 1 : rank - 2 ;
735
+
771
736
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA ();
772
737
auto tilesPerWarp = mfmaLayout.getTilesPerWarp ();
773
- auto tilePerWarpNonK = tilesPerWarp[kDim ];
738
+ auto tilePerWarpNonK = tilesPerWarp[kDimIndex ];
739
+
740
+ auto mDim = mfmaLayout.getMDim ();
741
+ auto nDim = mfmaLayout.getNDim ();
742
+ auto opIdx = dotMfmaLayout.getOpIdx ();
743
+ auto nonKDim = opIdx == 0 ? mDim : nDim;
744
+ constexpr int warpSize = 64 ;
745
+
746
+ int32_t kSize = shape[kDimIndex ];
774
747
775
748
MLIRContext *ctx = dotMfmaLayout.getContext ();
776
749
SmallVector<StringAttr> outDimNames = standardOutDimNames (ctx, rank);
@@ -785,57 +758,43 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
785
758
// for both cases it is [k, nonk]/[k, nonk, batch]
786
759
auto order =
787
760
getOrderForDotOperand (dotMfmaLayout.getOpIdx (), rank, /* kContig*/ true );
761
+ auto dimK = outDimNames[order[0 ]];
762
+ auto dimNonK = outDimNames[order[1 ]];
788
763
789
764
// warp order
790
765
// common for both operand A and B: [0, 1] / [0, 1, 2]
791
766
// in both cases it is [M dim, N dim]/[batch, M dim, N dim]
792
767
auto warpOrder = getDefaultMmaOrder (mfmaLayout);
793
768
794
- // Lane holds kWidth consecutive elements along k dimension, so
795
- // base register vectors for one tile are initialized in following way:
796
- // {1, 0}, {2, 0} ... {kWidth/2, 0}
797
- std::vector<std::vector<int32_t >> registerBase;
798
- for (int32_t elem = 1 ; elem < kWidth ; elem *= 2 )
799
- registerBase.emplace_back (std::vector<int32_t >{elem, 0 });
800
-
801
- std::vector<std::vector<int32_t >> laneBase;
802
- int32_t kTileSize = -1 ;
803
-
804
- if (mfmaLayout.getMDim () == 32 ) {
805
- // Canonical MFMA linear layout handles 4 consecutive elements along
806
- // the register dimension. Dot operand handles variable kWidth consecutive
807
- // elements. For lane dim, since the MFMA thread arrangement is {K, N} = {2,
808
- // 32}, this means that mapping of first 5 base (up to thread 16) vectors
809
- // will be an identity along N dim. Thread 32 will be mapped to element
810
- // kWidth in K dimension.
811
- laneBase = {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {0 , 16 }, {kWidth , 0 }};
812
- kTileSize = kWidth * 2 ;
813
- } else {
814
- assert (mfmaLayout.getMDim () == 16 );
815
- // For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this
816
- // means that mapping of first 4 base (up to thread 16) vectors will be an
817
- // identity along N dim. Thread 16 will be mapped to element kWisth in K
818
- // dimension. Thread 32 is mapped to element 2*kWidth in K dim.
819
- laneBase = {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {kWidth , 0 }, {kWidth * 2 , 0 }};
820
- kTileSize = kWidth * 4 ;
821
- }
822
- assert (kTileSize != -1 );
823
- // Add repeats of registers along K dimension to register base vectors
824
- for (int32_t elem = kTileSize ; elem < kSize ; elem *= 2 )
825
- registerBase.emplace_back (std::vector<int32_t >{elem, 0 });
769
+ // Each lane holds kWidth elements along the K dimension
770
+ LinearLayout regs = LinearLayout::identity1D (kWidth , kRegister , dimK);
771
+ // First distribute nonKDim elements along the non-K dimension,
772
+ // then distribute remaining elements along the K dimension
773
+ LinearLayout lanes =
774
+ LinearLayout::identity1D (nonKDim, kLane , dimNonK) *
775
+ LinearLayout::identity1D (warpSize / nonKDim, kLane , dimK);
776
+ LinearLayout tileLayout = regs * lanes;
826
777
827
- // Base vectors above are defined in a fixed order [k-dim, non-k-dim].
828
- // To assign them to actual matrix dimensions we assoicate with register
829
- // `order` which is also also [k, nonk].
778
+ int kTileSize = warpSize / nonKDim * kWidth ;
779
+ // Special case for 4x64 and 64x4 mfma: for the 64x64 operand,
780
+ // we need to repeat the layout 16 times along the K dimension
781
+ if ((mDim == 64 && nDim == 4 && opIdx == 0 ) ||
782
+ (mDim == 4 && nDim == 64 && opIdx == 1 )) {
783
+ tileLayout *= LinearLayout::identity1D (16 , kRegister , dimK);
784
+ kTileSize *= 16 ;
785
+ }
830
786
831
- // Add repeats of registers along non-K dimension to register base vectors
832
- for (int32_t elem = mfmaLayout.getMDim ();
833
- elem < tilePerWarpNonK * mfmaLayout.getMDim (); elem *= 2 )
834
- registerBase.emplace_back (std::vector<int32_t >{0 , elem});
787
+ // If shape K is larger than the tile size, repeat the tile
788
+ // along the K dimension.
789
+ if (kSize > kTileSize ) {
790
+ tileLayout *= LinearLayout::identity1D (kSize / kTileSize , kRegister , dimK);
791
+ }
835
792
836
- LinearLayout tileLayout ({{kRegister , registerBase}, {kLane , laneBase}},
837
- {outDimNames[order[0 ]], outDimNames[order[1 ]]});
793
+ // Follow the tiles per warp property, repeat the tile layout
794
+ // along the non-K dimension.
795
+ tileLayout *= LinearLayout::identity1D (tilePerWarpNonK, kRegister , dimNonK);
838
796
797
+ tileLayout = tileLayout.transposeOuts ({dimK, dimNonK});
839
798
if (hasBatchDim) {
840
799
assert (order[2 ] == 0 );
841
800
// Extend the base vector with one value to accommodate for the batch
@@ -845,11 +804,14 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
845
804
}
846
805
847
806
LinearLayout warpLayout = identityStandardND (kWarp , warpsPerCTA, warpOrder);
848
-
849
- LinearLayout ctaLayout = tileLayout.transposeOuts (outDimNames) *
850
- warpLayout.transposeOuts (outDimNames);
851
-
852
- return combineCtaCgaWithShape (ctaLayout, mfmaLayout.getCTALayout (), shape);
807
+ LinearLayout ctaLayout = tileLayout * warpLayout;
808
+
809
+ // Note that the current output order is [k, nonk]/[k, nonk, batch]. If the
810
+ // layout's out-size is smaller than the shape, we follow this order to
811
+ // extend each dimension to match the shape. After that, we can transpose
812
+ // to match the standard output order.
813
+ return combineCtaCgaWithShape (ctaLayout, mfmaLayout.getCTALayout (), shape)
814
+ .transposeOuts (outDimNames);
853
815
}
854
816
855
817
LinearLayout
0 commit comments