@@ -576,16 +576,21 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
576
576
auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent ());
577
577
auto mDim = mfmaLayout.getMDim ();
578
578
assert (mDim == 16 || mDim == 32 );
579
+
580
+ bool isFP4 = false ;
581
+ if (elemBitWidth == 4 ) {
582
+ // When doing ds_read_tr4 we actually write the LL as if it were on i8
583
+ // elements this is becasue LL needs to be described for the i8 tensor
584
+ // elements.
585
+ elemBitWidth = 8 ;
586
+ isFP4 = true ;
587
+ }
588
+
579
589
assert (elemBitWidth == 16 || elemBitWidth == 8 );
580
590
581
591
auto rank = shape.size ();
582
592
bool hasBatchDim = rank == 3 ;
583
593
int32_t kWidthDot = dotMfmaLayout.getKWidth ();
584
- // Number of bits loaded by an LDS read. ds_read_tr primarily supports 64-bit
585
- // loads for most element sizes (16b, 8b, 4b).
586
- const int32_t ldsReadWidth = 64 ;
587
- int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
588
- const int elemByteWidth = elemBitWidth / 8 ;
589
594
auto kDim = dotMfmaLayout.getOpIdx () == 0 ? rank - 1 : rank - 2 ;
590
595
591
596
int32_t kSize = shape[kDim ];
@@ -606,106 +611,151 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
606
611
SmallVector<unsigned > order =
607
612
getOrderForDotOperand (dotMfmaLayout.getOpIdx (), rank, /* kContig*/ false );
608
613
609
- // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes)
610
- // of data. The smallest unit for transposition is a
611
- // [non-K, K] = {16, kWidthTransRead} sub-tile of elements,
612
- // where each thread reads kWidthTransRead elements along the non-K dimension.
613
- // Due to the transposition mechanism, each thread ends up with
614
- // kWidthTransRead elements along the K dimension.
615
- //
616
- // The MFMA selection logic prioritizes double-rate MFMA instructions whenever
617
- // possible:
618
- //
619
- // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k
620
- // is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice.
621
- //
622
- // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is
623
- // selected; otherwise (blockK ≤ k), mfma32x32xk is used.
624
- //
625
- // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA
626
- // instructions are used.
627
- //
628
- // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead
629
- // elements along the K dimension:
630
- // - The first kWidthTransRead elements belong to the first sub-tile.
631
- // - The next kWidthTransRead elements belong to the second sub-tile.
632
- //
633
- // These elements are then grouped into larger tiles, each consisting of
634
- // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data
635
- // for one MFMA instruction. The shape of these tiles depends on the MFMA
636
- // instruction used.
637
- //
638
- // For single-rate MFMA instructions, each thread holds kWidthTransRead
639
- // elements along the K dimension. This means that the larger tile
640
- // (corresponding to one MFMA instruction) consists of 4 {16, kWidthTransRead}
641
- // sub-tiles.
642
614
std::vector<std::vector<int32_t >> registerBase;
643
615
std::vector<std::vector<int32_t >> laneBase;
616
+ auto populateFP4LL = [®isterBase, &laneBase](int kSize , int mDim ) {
617
+ const bool isMfma32 = (mDim == 32 );
618
+ // ds_read_b64_tr4 operates on FP4 values swapping the packing of them. Look
619
+ // at i8 values for the ownership of register/lane since it's the data type
620
+ // of the tensor. Register dimension: what i8 in the tile are held by thread
621
+ // 0? Lane dimension: what i8 in the tile are held in register 0 of each
622
+ // thread?
623
+ registerBase.push_back ({1 , 0 });
624
+ registerBase.push_back ({2 , 0 });
625
+ registerBase.push_back ({4 , 0 });
626
+ registerBase.push_back ({0 , 16 });
627
+
628
+ // If more than one tile needs to be loaded, populate registerBase
629
+ // dimension for the other tiles
630
+ const int kTileSize = isMfma32 ? 64 : 128 ;
631
+ for (int reg = kTileSize ; reg < kSize ; reg *= 2 ) {
632
+ registerBase.push_back ({0 , reg});
633
+ }
644
634
645
- // Populate register base for first subtile
646
- for (int i = 1 ; i < kWidthTransRead ; i *= 2 ) {
647
- registerBase.push_back ({i, 0 });
648
- }
649
-
650
- const int threadsPerSubtileNonK = 16 / kWidthTransRead ;
651
- const int threadsPerSubtileK = kWidthTransRead ;
652
-
653
- // Populate lane base for first subtile
654
- for (int i = 1 ; i < threadsPerSubtileNonK; i *= 2 ) {
655
- laneBase.push_back ({i * kWidthTransRead , 0 });
656
- }
657
- for (int i = 1 ; i < threadsPerSubtileK; i *= 2 ) {
658
- laneBase.push_back ({0 , i});
659
- }
660
-
661
- // Function to extend register base for multiple tiles K dim.
662
- auto extendRegisterBaseForKDim = [&](int kTileSize , int numSubtilesPerTile) {
663
- const int regsPerTile = kWidthTransRead * numSubtilesPerTile;
664
- int totalRegs = (kSize / kTileSize ) * regsPerTile;
665
-
666
- for (int reg = regsPerTile; reg < totalRegs; reg *= 2 ) {
667
- registerBase.push_back ({0 , (reg / regsPerTile) * kTileSize });
635
+ // When mDim == 16 we have 16x128 mfma, otherwise it's 16x64
636
+ // The LL for the two is different
637
+ laneBase.push_back ({0 , 1 });
638
+ laneBase.push_back ({0 , 2 });
639
+ laneBase.push_back ({0 , 4 });
640
+ laneBase.push_back ({0 , 8 });
641
+ if (mDim == 16 ) {
642
+ laneBase.push_back ({0 , 32 });
643
+ laneBase.push_back ({0 , 64 });
644
+ } else {
645
+ assert (mDim == 32 );
646
+ laneBase.push_back ({8 , 0 });
647
+ laneBase.push_back ({0 , 32 });
668
648
}
669
649
};
650
+ auto populateLL = [®isterBase, &laneBase](int elemBitWidth, int kSize ,
651
+ int kWidthDot , int mDim ) {
652
+ // Number of bits loaded by an LDS read. ds_read_tr primarily supports
653
+ // 64-bit loads for most element sizes (16b, 8b, 4b).
654
+ const int32_t ldsReadWidth = 64 ;
655
+ int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
656
+ const int elemByteWidth = elemBitWidth / 8 ;
657
+ const bool isMfma32 = (mDim == 32 );
658
+
659
+ // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes)
660
+ // of data. The smallest unit for transposition is a
661
+ // [non-K, K] = {16, kWidthTransRead} sub-tile of elements,
662
+ // where each thread reads kWidthTransRead elements along the non-K
663
+ // dimension. Due to the transposition mechanism, each thread ends up with
664
+ // kWidthTransRead elements along the K dimension.
665
+ //
666
+ // The MFMA selection logic prioritizes double-rate MFMA instructions
667
+ // whenever possible:
668
+ //
669
+ // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k
670
+ // is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice.
671
+ //
672
+ // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is
673
+ // selected; otherwise (blockK ≤ k), mfma32x32xk is used.
674
+ //
675
+ // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA
676
+ // instructions are used.
677
+ //
678
+ // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead
679
+ // elements along the K dimension:
680
+ // - The first kWidthTransRead elements belong to the first sub-tile.
681
+ // - The next kWidthTransRead elements belong to the second sub-tile.
682
+ //
683
+ // These elements are then grouped into larger tiles, each consisting of
684
+ // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data
685
+ // for one MFMA instruction. The shape of these tiles depends on the MFMA
686
+ // instruction used.
687
+ //
688
+ // For single-rate MFMA instructions, each thread holds kWidthTransRead
689
+ // elements along the K dimension. This means that the larger tile
690
+ // (corresponding to one MFMA instruction) consists of 4 {16,
691
+ // kWidthTransRead} sub-tiles.
692
+
693
+ // Populate register base for first subtile
694
+ for (int i = 1 ; i < kWidthTransRead ; i *= 2 ) {
695
+ registerBase.push_back ({i, 0 });
696
+ }
670
697
671
- const bool isMfma32 = (mDim == 32 );
672
- const bool isMfma16 = (mDim == 16 );
673
-
674
- // kDoubleTileSize is the k dimension of a tile when double rated
675
- // mfma instructions are used.
676
- const int kDoubleTileSize =
677
- isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth;
678
- // kTileSize is the actually k dimention of a tile, which is
679
- // determined by kWidthDot.
680
- const int kTileSize = kWidthDot * 64 / mDim ;
681
- // We use kDoubleTileSize as a reference to check whether the given
682
- // kWidthDot leads to double or single sub-tiles in each tile.
683
- const int numSubtilesPerTile = (kTileSize == kDoubleTileSize ) ? 2 : 1 ;
684
-
685
- // Extend register base for large K sizes.
686
- if (numSubtilesPerTile == 2 )
687
- registerBase.push_back ({0 , threadsPerSubtileK}); // Second subtile
688
-
689
- extendRegisterBaseForKDim (kTileSize , numSubtilesPerTile);
698
+ const int threadsPerSubtileNonK = 16 / kWidthTransRead ;
699
+ const int threadsPerSubtileK = kWidthTransRead ;
690
700
691
- // Extend lane base based on MFMA size.
692
- std::vector<std::vector<int32_t >> laneBaseExt;
701
+ // Populate lane base for first subtile
702
+ for (int i = 1 ; i < threadsPerSubtileNonK; i *= 2 ) {
703
+ laneBase.push_back ({i * kWidthTransRead , 0 });
704
+ }
705
+ for (int i = 1 ; i < threadsPerSubtileK; i *= 2 ) {
706
+ laneBase.push_back ({0 , i});
707
+ }
693
708
694
- if (isMfma32) {
695
- laneBaseExt = {{16 , 0 }, {0 , numSubtilesPerTile * threadsPerSubtileK}};
696
- } else {
697
- laneBaseExt = {{0 , numSubtilesPerTile * threadsPerSubtileK},
698
- {0 , 2 * numSubtilesPerTile * threadsPerSubtileK}};
699
- }
709
+ // Function to extend register base for multiple tiles K dim.
710
+ auto extendRegisterBaseForKDim = [&](int kTileSize ,
711
+ int numSubtilesPerTile) {
712
+ const int regsPerTile = kWidthTransRead * numSubtilesPerTile;
713
+ int totalRegs = (kSize / kTileSize ) * regsPerTile;
714
+
715
+ for (int reg = regsPerTile; reg < totalRegs; reg *= 2 ) {
716
+ registerBase.push_back ({0 , (reg / regsPerTile) * kTileSize });
717
+ }
718
+ };
719
+
720
+ // kDoubleTileSize is the k dimension of a tile when double rated
721
+ // mfma instructions are used.
722
+ const int kDoubleTileSize =
723
+ isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth;
724
+ // kTileSize is the actually k dimention of a tile, which is
725
+ // determined by kWidthDot.
726
+ const int kTileSize = kWidthDot * 64 / mDim ;
727
+ // We use kDoubleTileSize as a reference to check whether the given
728
+ // kWidthDot leads to double or single sub-tiles in each tile.
729
+ const int numSubtilesPerTile = (kTileSize == kDoubleTileSize ) ? 2 : 1 ;
730
+
731
+ // Extend register base for large K sizes.
732
+ if (numSubtilesPerTile == 2 )
733
+ registerBase.push_back ({0 , threadsPerSubtileK}); // Second subtile
734
+
735
+ extendRegisterBaseForKDim (kTileSize , numSubtilesPerTile);
736
+
737
+ // Extend lane base based on MFMA size.
738
+ std::vector<std::vector<int32_t >> laneBaseExt;
739
+
740
+ if (isMfma32) {
741
+ laneBaseExt = {{16 , 0 }, {0 , numSubtilesPerTile * threadsPerSubtileK}};
742
+ } else {
743
+ laneBaseExt = {{0 , numSubtilesPerTile * threadsPerSubtileK},
744
+ {0 , 2 * numSubtilesPerTile * threadsPerSubtileK}};
745
+ }
746
+ laneBase.insert (laneBase.end (), laneBaseExt.begin (), laneBaseExt.end ());
747
+ };
700
748
701
- laneBase.insert (laneBase.end (), laneBaseExt.begin (), laneBaseExt.end ());
749
+ if (isFP4)
750
+ populateFP4LL (kSize , mDim );
751
+ else
752
+ populateLL (elemBitWidth, kSize , kWidthDot , mDim );
702
753
703
754
// Base vectors above are defined in a fixed order [non-k-dim, k-dim].
704
755
// To assign them to actual matrix dimensions we associate with register
705
756
// `order` which is also [nonk, k] given we set kContig to false.
706
757
LinearLayout tileLayout ({{kRegister , registerBase}, {kLane , laneBase}},
707
758
{outDimNames[order[0 ]], outDimNames[order[1 ]]});
708
-
709
759
if (hasBatchDim) {
710
760
assert (order[2 ] == 0 );
711
761
// Extend the base vector with one value to accommodate for the batch
0 commit comments