@@ -393,12 +393,12 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
393393 return combineCtaCgaWithShape (ctaLayout, getCTALayout (), shape);
394394}
395395
396- LinearLayout chooseDotDsReadB64Tr16Layout (DotOperandEncodingAttr dotMfmaLayout,
397- ArrayRef<int64_t > shape,
398- int32_t elemBitWidth) {
396+ LinearLayout chooseDotDsReadB64TrLayout (DotOperandEncodingAttr dotMfmaLayout,
397+ ArrayRef<int64_t > shape,
398+ int32_t elemBitWidth) {
399399 auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent ());
400400 assert (mfmaLayout.getMDim () == 16 || mfmaLayout.getNDim () == 32 );
401- assert (elemBitWidth == 16 );
401+ assert (elemBitWidth == 16 || elemBitWidth == 8 );
402402
403403 auto rank = shape.size ();
404404 bool hasBatchDim = rank == 3 ;
@@ -407,6 +407,7 @@ LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
407407 // loads for most element sizes (16b, 8b, 4b).
408408 const int32_t ldsReadWidth = 64 ;
409409 int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
410+ const int elemByteWidth = elemBitWidth / 8 ;
410411 auto kDim = dotMfmaLayout.getOpIdx () == 0 ? rank - 1 : rank - 2 ;
411412
412413 int32_t kSize = shape[kDim ];
@@ -427,72 +428,92 @@ LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
427428 SmallVector<unsigned > order = dotMfmaLayout.getDefaultOrder ();
428429 std::swap (order[0 ], order[1 ]);
429430
430- // In the LDS transpose logic, each thread accesses 64 bits (8 bytes) of data.
431- // The smallest unit for transposing is a 4x4 sub-tile of threads, where each
432- // thread reads 4 16-bit elements along the non-K dimension, resulting in a
433- // [non-K, K] = {16, 4} sub-tile of elements. Because of transposing
434- // mechanism, thread ends up with 4 16-bit elements along K dim.
431+ // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes)
432+ // of data. The smallest unit for transposition is a
433+ // [non-K, K] = {16, kWidthTransRead} sub-tile of elements,
434+ // where each thread reads kWidthTransRead elements along the non-K dimension.
435+ // Due to the transposition mechanism, each thread ends up with
436+ // kWidthTransRead elements along the K dimension.
435437 //
436438 // The MFMA selection logic prioritizes double-rate MFMA instructions whenever
437- // possible. Specifically:
438- // - For MFMA operations that are non-K = 16, when blockK > 16, mfma16x16x32
439- // is selected; otherwise (blockK ≤ 16), mfma16x16x16 remains the choice.
440- // - For MFMA operations that are non-K = 32, when blockK > 8, mfma32x32x16 is
441- // selected; otherwise (blockK ≤ 8), mfma32x32x8 is used.
439+ // possible:
442440 //
443- // In double-rate MFMA instructions, each thread holds 8 elements along the K
444- // dimension.
445- // - The first 4 elements belong to the first sub-tile.
446- // - The next 4 elements belong to the second sub-tile.
441+ // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k
442+ // is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice.
447443 //
448- // We then group these into larger tiles, each consisting of 8 of these 16x4
449- // sub-tiles. These tiles correspond to data for one mfma instruction. The
450- // shapes of these tiles depend on the MFMA instruction used:
451- // 1. For mfma32x32x16, the tile shape is [non-K, K] = {32, 16}.
452- // 2. For mfma16x16x32, the tile shape is [non-K, K] = {16, 32}.
444+ // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is
445+ // selected; otherwise (blockK ≤ k), mfma32x32xk is used.
453446 //
454- // For single-rate mfma instructions, each thread holds 4 elements along K
455- // dimension. This means larger tile (that corresponds to one mfma
456- // instruction) consists of 4 16x4 sub-tiles.
457- std::vector<std::vector<int32_t >> registerBase = {{1 , 0 },
458- {2 , 0 }}; // first sub-tile
459- std::vector<std::vector<int32_t >> laneBase = {{kWidthTransRead , 0 },
460- {2 * kWidthTransRead , 0 },
461- {0 , 1 },
462- {0 , 2 }}; // first sub-tile
463-
464- // Extend register base for multiple tiles in K dimension (corresponding to
465- // multiple mfma instructions accross k dim).
466- auto populateRegisterBase = [&](int kTileSize ) {
467- const int regsPerTile = 8 ;
468- int numRegs = (kSize / kTileSize ) * regsPerTile;
469- for (int reg = regsPerTile; reg < numRegs; reg *= 2 ) {
447+ // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA
448+ // instructions are used.
449+ //
450+ // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead
451+ // elements along the K dimension:
452+ // - The first kWidthTransRead elements belong to the first sub-tile.
453+ // - The next kWidthTransRead elements belong to the second sub-tile.
454+ //
455+ // These elements are then grouped into larger tiles, each consisting of
456+ // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data
457+ // for one MFMA instruction. The shape of these tiles depends on the MFMA
458+ // instruction used.
459+ //
460+ // For single-rate MFMA instructions, each thread holds kWidthTransRead
461+ // elements along the K dimension. This means that the larger tile
462+ // (corresponding to one MFMA instruction) consists of 4 {16, kWidthTransRead}
463+ // sub-tiles.
464+ std::vector<std::vector<int32_t >> registerBase;
465+ std::vector<std::vector<int32_t >> laneBase;
466+
467+ // Populate register base for first subtile
468+ for (int i = 1 ; i < kWidthTransRead ; i *= 2 ) {
469+ registerBase.push_back ({i, 0 });
470+ }
471+
472+ const int threadsPerSubtileNonK = 16 / kWidthTransRead ;
473+ const int threadsPerSubtileK = kWidthTransRead ;
474+
475+ // Populate lane base for first subtile
476+ for (int i = 1 ; i < threadsPerSubtileNonK; i *= 2 ) {
477+ laneBase.push_back ({i * kWidthTransRead , 0 });
478+ }
479+ for (int i = 1 ; i < threadsPerSubtileK; i *= 2 ) {
480+ laneBase.push_back ({0 , i});
481+ }
482+
483+ // Function to extend register base for multiple tiles in the K dim.
484+ auto extendRegisterBaseForKDim = [&](int kTileSize ) {
485+ const int regsPerTile = kWidthTransRead * 2 ; // Two subtiles per tile
486+ int totalRegs = (kSize / kTileSize ) * regsPerTile;
487+
488+ for (int reg = regsPerTile; reg < totalRegs; reg *= 2 ) {
470489 registerBase.push_back ({0 , (reg / regsPerTile) * kTileSize });
471490 }
472491 };
473492
474493 const bool isMfma32 = (mfmaLayout.getMDim () == 32 );
475494 const bool isMfma16 = (mfmaLayout.getMDim () == 16 );
476- const int kTileSize = isMfma32 ? 16 : 32 ;
477-
478- if (kSize >= kTileSize ) {
479- // Handles mfma32x32x16 and mfma16x16x32 cases
480- assert (kWidthDot == 8 );
481- registerBase.push_back ({0 , 4 }); // second sub-tile
482- populateRegisterBase (kTileSize );
483- auto laneBaseExt = isMfma32
484- ? std::vector<std::vector<int32_t >>{{16 , 0 }, {0 , 8 }}
485- : std::vector<std::vector<int32_t >>{{0 , 8 }, {0 , 16 }};
486- laneBase.insert (laneBase.end (), laneBaseExt.begin (), laneBaseExt.end ());
495+ const int kTileSize = isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth;
496+ const bool largeKSize = kSize >= kTileSize ;
497+
498+ // Extend register base for large K sizes.
499+ if (largeKSize) {
500+ registerBase.push_back ({0 , threadsPerSubtileK}); // Second subtile
501+ extendRegisterBaseForKDim (kTileSize );
502+ }
503+
504+ // Extend lane base based on MFMA size.
505+ const int numSubtilesPerTile = largeKSize ? 2 : 1 ;
506+ std::vector<std::vector<int32_t >> laneBaseExt;
507+
508+ if (isMfma32) {
509+ laneBaseExt = {{16 , 0 }, {0 , numSubtilesPerTile * threadsPerSubtileK}};
487510 } else {
488- // Handles mfma32x32x8 and mfma16x16x16 cases
489- assert (kWidthDot == 4 );
490- auto laneBaseExt = isMfma32
491- ? std::vector<std::vector<int32_t >>{{16 , 0 }, {0 , 4 }}
492- : std::vector<std::vector<int32_t >>{{0 , 4 }, {0 , 8 }};
493- laneBase.insert (laneBase.end (), laneBaseExt.begin (), laneBaseExt.end ());
511+ laneBaseExt = {{0 , numSubtilesPerTile * threadsPerSubtileK},
512+ {0 , 2 * numSubtilesPerTile * threadsPerSubtileK}};
494513 }
495514
515+ laneBase.insert (laneBase.end (), laneBaseExt.begin (), laneBaseExt.end ());
516+
496517 // Base vectors above are defined in a fixed order [non-k-dim, k-dim].
497518 // To assign them to actual matrix dimensions `order` array is used.
498519 // For operand A: non-k-dim -> dim0, k-dim -> dim1
@@ -516,10 +537,7 @@ LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
516537
517538 LinearLayout ctaLayout = tileLayout.transposeOuts (outDimNames) *
518539 warpLayout.transposeOuts (outDimNames);
519- auto finalLayout =
520- combineCtaCgaWithShape (ctaLayout, mfmaLayout.getCTALayout (), shape);
521-
522- return finalLayout;
540+ return combineCtaCgaWithShape (ctaLayout, mfmaLayout.getCTALayout (), shape);
523541}
524542
525543LinearLayout mfmaDotToLinearLayout (DotOperandEncodingAttr dotMfmaLayout,
@@ -1334,10 +1352,10 @@ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
13341352 return chooseDotLdMatrixLayout (dot, shape, needTrans, elemBitWidth);
13351353}
13361354
1337- LinearLayout chooseDsReadB64Tr16Layout (Attribute enc, ArrayRef<int64_t > shape,
1338- int32_t elemBitWidth) {
1355+ LinearLayout chooseDsReadB64TrLayout (Attribute enc, ArrayRef<int64_t > shape,
1356+ int32_t elemBitWidth) {
13391357 auto dot = cast<DotOperandEncodingAttr>(enc);
1340- return chooseDotDsReadB64Tr16Layout (dot, shape, elemBitWidth);
1358+ return chooseDotDsReadB64TrLayout (dot, shape, elemBitWidth);
13411359}
13421360
13431361LinearLayout
0 commit comments