@@ -477,23 +477,87 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
477477 return combineCtaCgaWithShape (tileLayout, getCTALayout (), shape);
478478}
479479
480+ LinearLayout chooseLLDsReadB64TrLayout (Attribute enc, ArrayRef<int64_t > shape,
481+ int32_t elemBitWidth) {
482+ using BaseTy = std::vector<std::vector<int32_t >>;
483+ // This function will derive the layout for the ds_read_b64_tr instruction
484+ // based on the input layout (LL/DotLayout/...)
485+ // The ds_read_b64_tr works on 64 bits per lane and in groups of 16 lanes.
486+
487+ // Using M-continuous 16-bit input tensor A as an example. Each lane will
488+ // load 4 consecutive elements (64-bit in total) along M. There are 4
489+ // consecutive lanes in total along M. Then the loaded elements are exchanged
490+  // within the MxK=16x4 "base unit".
491+ // K0 K1 K2 K3
492+ // +---+---+---+---+
493+ // M0 | | | | | M0, K[0-3]: T0
494+ // M1 | T | T | T | T | M1, K[0-3]: T1
495+ // M2 | 0 | 4 | 8 |12 | M2, K[0-3]: T2
496+ // M3 | | | | | M3, K[0-3]: T3
497+ // +---+---+---+---+
498+ // M4 | | | | | M4, K[0-3]: T4
499+ // M5 | T | T | T | T | M5, K[0-3]: T5
500+ // M6 | 1 | 5 | 9 |13 | M6, K[0-3]: T6
501+ // M7 | | | | | M7, K[0-3]: T7
502+ // +---+---+---+---+ ==>
503+ // M8 | | | | | M8, K[0-3]: T8
504+ // M9 | T | T | T | T | M9, K[0-3]: T9
505+ // M10 | 2 | 6 |10 |14 | M10, K[0-3]: T10
506+ // M11 | | | | | M11, K[0-3]: T11
507+ // +---+---+---+---+
508+ // M12 | | | | | M12, K[0-3]: T12
509+ // M13 | T | T | T | T | M13, K[0-3]: T13
510+ // M14 | 3 | 7 |11 |15 | M14, K[0-3]: T14
511+ // M15 | | | | | M15, K[0-3]: T15
512+ // +---+---+---+---+
513+
514+ // Given the layout represented by `enc` and shape, we can derive the layout
515+  // that ds_read_b64_tr needs to have in order to perform a vectorized load of
516+  // the elements. This can be done by permuting the inner 4x16 element base
517+  // unit in the LL, i.e. by rotating the first numReg register bases and the
518+  // first numLane lane bases.
519+ auto rotatePrefixes = [](BaseTy ®Base, std::size_t numReg,
520+ BaseTy &laneBase, std::size_t numLane) {
521+ // Concatenate prefixes of the two vectors. Lane first and then regs.
522+ // C D E F | A B
523+ // Then copy over numReg to the regBase and numLane to laneBase
524+ // C D | E F A B
525+ BaseTy baseUnit (laneBase.begin (), laneBase.begin () + numLane);
526+ llvm::append_range (
527+ baseUnit, llvm::make_range (regBase.begin (), regBase.begin () + numReg));
528+
529+ std::copy (baseUnit.begin (), baseUnit.begin () + numReg, regBase.begin ());
530+ std::copy (baseUnit.begin () + numReg, baseUnit.end (), laneBase.begin ());
531+ };
532+
533+ auto ctx = enc.getContext ();
534+ assert (elemBitWidth == 8 || elemBitWidth == 16 );
535+ // Get how many reg bases the ds_read_tr tile spans
536+ unsigned numRegBases = llvm::Log2_32 (64 / elemBitWidth);
537+ // 4 lane bases describe 16 lanes.
538+ unsigned numLaneBases = 4 ;
539+
540+ auto ldsTransLayout = triton::gpu::toLinearLayout (shape, enc);
541+ auto bases = ldsTransLayout.getBases ();
542+  auto kRegister = S ("register");
543+  auto kLane = S ("lane");
544+ rotatePrefixes (bases[kRegister ], numRegBases, bases[kLane ], numLaneBases);
545+
546+ return LinearLayout (bases, ldsTransLayout.getOutDims (), false );
547+ }
548+
480549LinearLayout chooseDotDsReadB64TrLayout (DotOperandEncodingAttr dotMfmaLayout,
481550 ArrayRef<int64_t > shape,
482551 int32_t elemBitWidth) {
483552 auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent ());
484553 auto mDim = mfmaLayout.getInstrShape ()[0 ];
485554 assert (mDim == 16 || mDim == 32 );
486555
487- bool isFP4 = false ;
488- if (elemBitWidth == 4 ) {
489- // When doing ds_read_tr4 we actually write the LL as if it were on i8
490- // elements this is becasue LL needs to be described for the i8 tensor
491- // elements.
492- elemBitWidth = 8 ;
493- isFP4 = true ;
494- }
495-
496- assert (elemBitWidth == 16 || elemBitWidth == 8 );
556+ assert (elemBitWidth == 4 );
557+ // When doing ds_read_tr4 we actually write the LL as if it were on i8
558+  // elements; this is because LL needs to be described for the i8 tensor
559+ // elements.
560+ elemBitWidth = 8 ;
497561
498562 auto rank = shape.size ();
499563 bool hasBatchDim = rank == 3 ;
@@ -520,143 +584,39 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
520584
521585 std::vector<std::vector<int32_t >> registerBase;
522586 std::vector<std::vector<int32_t >> laneBase;
523- auto populateFP4LL = [®isterBase, &laneBase](int kSize , int mDim ) {
524- const bool isMfma32 = (mDim == 32 );
525- // ds_read_b64_tr4 operates on FP4 values swapping the packing of them. Look
526- // at i8 values for the ownership of register/lane since it's the data type
527- // of the tensor. Register dimension: what i8 in the tile are held by thread
528- // 0? Lane dimension: what i8 in the tile are held in register 0 of each
529- // thread?
530- registerBase.push_back ({1 , 0 });
531- registerBase.push_back ({2 , 0 });
532- registerBase.push_back ({4 , 0 });
533- registerBase.push_back ({0 , 16 });
534-
535- // If more than one tile needs to be loaded, populate registerBase
536- // dimension for the other tiles
537- const int kTileSize = isMfma32 ? 64 : 128 ;
538- for (int reg = kTileSize ; reg < kSize ; reg *= 2 ) {
539- registerBase.push_back ({0 , reg});
540- }
541-
542- // When mDim == 16 we have 16x128 mfma, otherwise it's 16x64
543- // The LL for the two is different
544- laneBase.push_back ({0 , 1 });
545- laneBase.push_back ({0 , 2 });
546- laneBase.push_back ({0 , 4 });
547- laneBase.push_back ({0 , 8 });
548- if (mDim == 16 ) {
549- laneBase.push_back ({0 , 32 });
550- laneBase.push_back ({0 , 64 });
551- } else {
552- assert (mDim == 32 );
553- laneBase.push_back ({8 , 0 });
554- laneBase.push_back ({0 , 32 });
555- }
556- };
557- auto populateLL = [®isterBase, &laneBase](int elemBitWidth, int kSize ,
558- int kWidthDot , int mDim ) {
559- // Number of bits loaded by an LDS read. ds_read_tr primarily supports
560- // 64-bit loads for most element sizes (16b, 8b, 4b).
561- const int32_t ldsReadWidth = 64 ;
562- int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
563- const int elemByteWidth = elemBitWidth / 8 ;
564- const bool isMfma32 = (mDim == 32 );
565-
566- // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes)
567- // of data. The smallest unit for transposition is a
568- // [non-K, K] = {16, kWidthTransRead} sub-tile of elements,
569- // where each thread reads kWidthTransRead elements along the non-K
570- // dimension. Due to the transposition mechanism, each thread ends up with
571- // kWidthTransRead elements along the K dimension.
572- //
573- // The MFMA selection logic prioritizes double-rate MFMA instructions
574- // whenever possible:
575- //
576- // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k
577- // is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice.
578- //
579- // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is
580- // selected; otherwise (blockK ≤ k), mfma32x32xk is used.
581- //
582- // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA
583- // instructions are used.
584- //
585- // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead
586- // elements along the K dimension:
587- // - The first kWidthTransRead elements belong to the first sub-tile.
588- // - The next kWidthTransRead elements belong to the second sub-tile.
589- //
590- // These elements are then grouped into larger tiles, each consisting of
591- // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data
592- // for one MFMA instruction. The shape of these tiles depends on the MFMA
593- // instruction used.
594- //
595- // For single-rate MFMA instructions, each thread holds kWidthTransRead
596- // elements along the K dimension. This means that the larger tile
597- // (corresponding to one MFMA instruction) consists of 4 {16,
598- // kWidthTransRead} sub-tiles.
599-
600- // Populate register base for first subtile
601- for (int i = 1 ; i < kWidthTransRead ; i *= 2 ) {
602- registerBase.push_back ({i, 0 });
603- }
604-
605- const int threadsPerSubtileNonK = 16 / kWidthTransRead ;
606- const int threadsPerSubtileK = kWidthTransRead ;
607-
608- // Populate lane base for first subtile
609- for (int i = 1 ; i < threadsPerSubtileNonK; i *= 2 ) {
610- laneBase.push_back ({i * kWidthTransRead , 0 });
611- }
612- for (int i = 1 ; i < threadsPerSubtileK; i *= 2 ) {
613- laneBase.push_back ({0 , i});
614- }
615-
616- // Function to extend register base for multiple tiles K dim.
617- auto extendRegisterBaseForKDim = [&](int kTileSize ,
618- int numSubtilesPerTile) {
619- const int regsPerTile = kWidthTransRead * numSubtilesPerTile;
620- int totalRegs = (kSize / kTileSize ) * regsPerTile;
621-
622- for (int reg = regsPerTile; reg < totalRegs; reg *= 2 ) {
623- registerBase.push_back ({0 , (reg / regsPerTile) * kTileSize });
624- }
625- };
626-
627- // kDoubleTileSize is the k dimension of a tile when double rated
628- // mfma instructions are used.
629- const int kDoubleTileSize =
630- isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth;
631- // kTileSize is the actually k dimention of a tile, which is
632- // determined by kWidthDot.
633- const int kTileSize = kWidthDot * 64 / mDim ;
634- // We use kDoubleTileSize as a reference to check whether the given
635- // kWidthDot leads to double or single sub-tiles in each tile.
636- const int numSubtilesPerTile = (kTileSize == kDoubleTileSize ) ? 2 : 1 ;
637-
638- // Extend register base for large K sizes.
639- if (numSubtilesPerTile == 2 )
640- registerBase.push_back ({0 , threadsPerSubtileK}); // Second subtile
641-
642- extendRegisterBaseForKDim (kTileSize , numSubtilesPerTile);
643-
644- // Extend lane base based on MFMA size.
645- std::vector<std::vector<int32_t >> laneBaseExt;
646-
647- if (isMfma32) {
648- laneBaseExt = {{16 , 0 }, {0 , numSubtilesPerTile * threadsPerSubtileK}};
649- } else {
650- laneBaseExt = {{0 , numSubtilesPerTile * threadsPerSubtileK},
651- {0 , 2 * numSubtilesPerTile * threadsPerSubtileK}};
652- }
653- laneBase.insert (laneBase.end (), laneBaseExt.begin (), laneBaseExt.end ());
654- };
655587
656- if (isFP4)
657- populateFP4LL (kSize , mDim );
658- else
659- populateLL (elemBitWidth, kSize , kWidthDot , mDim );
588+ const bool isMfma32 = (mDim == 32 );
589+ // ds_read_b64_tr4 operates on FP4 values swapping the packing of them. Look
590+ // at i8 values for the ownership of register/lane since it's the data type
591+ // of the tensor. Register dimension: what i8 in the tile are held by thread
592+ // 0? Lane dimension: what i8 in the tile are held in register 0 of each
593+ // thread?
594+ registerBase.push_back ({1 , 0 });
595+ registerBase.push_back ({2 , 0 });
596+ registerBase.push_back ({4 , 0 });
597+ registerBase.push_back ({0 , 16 });
598+
599+ // If more than one tile needs to be loaded, populate registerBase
600+ // dimension for the other tiles
601+ const int kTileSize = isMfma32 ? 64 : 128 ;
602+ for (int reg = kTileSize ; reg < kSize ; reg *= 2 ) {
603+ registerBase.push_back ({0 , reg});
604+ }
605+
606+ // When mDim == 16 we have 16x128 mfma, otherwise it's 16x64
607+ // The LL for the two is different
608+ laneBase.push_back ({0 , 1 });
609+ laneBase.push_back ({0 , 2 });
610+ laneBase.push_back ({0 , 4 });
611+ laneBase.push_back ({0 , 8 });
612+ if (mDim == 16 ) {
613+ laneBase.push_back ({0 , 32 });
614+ laneBase.push_back ({0 , 64 });
615+ } else {
616+ assert (mDim == 32 );
617+ laneBase.push_back ({8 , 0 });
618+ laneBase.push_back ({0 , 32 });
619+ }
660620
661621 // Base vectors above are defined in a fixed order [non-k-dim, k-dim].
662622 // To assign them to actual matrix dimensions we associate with register
@@ -1444,8 +1404,12 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
14441404
14451405LinearLayout chooseDsReadB64TrLayout (Attribute enc, ArrayRef<int64_t > shape,
14461406 int32_t elemBitWidth) {
1447- auto dot = cast<DotOperandEncodingAttr>(enc);
1448- return chooseDotDsReadB64TrLayout (dot, shape, elemBitWidth);
1407+ if (elemBitWidth == 4 ) {
1408+ auto dot = cast<DotOperandEncodingAttr>(enc);
1409+ return chooseDotDsReadB64TrLayout (dot, shape, elemBitWidth);
1410+ } else {
1411+ return chooseLLDsReadB64TrLayout (enc, shape, elemBitWidth);
1412+ }
14491413}
14501414
14511415LinearLayout chooseScaledWmmaScaleLayout (
0 commit comments