@@ -391,6 +391,135 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }
 
+LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
+                                          ArrayRef<int64_t> shape,
+                                          int32_t elemBitWidth) {
+  auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
+  assert(mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32);
+  assert(elemBitWidth == 16);
+
+  auto rank = shape.size();
+  bool hasBatchDim = rank == 3;
+  int32_t kWidthDot = dotMfmaLayout.getKWidth();
+  // Number of bits loaded by an LDS read. ds_read_tr primarily supports 64-bit
+  // loads for most element sizes (16b, 8b, 4b).
+  const int32_t ldsReadWidth = 64;
+  int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
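+  // For the 16-bit elements handled here, each 64-bit read covers
+  // kWidthTransRead = 64 / 16 = 4 elements.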
+  auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+
+  int32_t kSize = shape[kDim];
+  auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
+
+  MLIRContext *ctx = dotMfmaLayout.getContext();
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+
+  StringAttr kRegister = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+
+  // register order
+  // operand A: [1, 0] / [2, 1, 0]
+  // operand B: [0, 1] / [1, 2, 0]
+  // The regular dot mfma order for both operands is [k, nonk]/[k, nonk, batch].
+  // For the LDS transpose layout, swap the order to [nonk, k]/[nonk, k, batch].
+  SmallVector<unsigned> order = triton::gpu::getOrder(dotMfmaLayout);
+  std::swap(order[0], order[1]);
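+  // e.g. for operand A with rank == 2, getOrder() returns [1, 0] ([k, nonk]);
+  // after the swap, order == [0, 1], i.e. [nonk, k].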
+
+  // In the LDS transpose logic, each thread accesses 64 bits (8 bytes) of
+  // data. The smallest unit for transposing is a 4x4 sub-tile of threads,
+  // where each thread reads 4 16-bit elements along the non-K dimension,
+  // resulting in a [non-K, K] = {16, 4} sub-tile of elements. Because of the
+  // transposing mechanism, each thread ends up with 4 16-bit elements along
+  // the K dimension.
+  //
+  // The MFMA selection logic prioritizes double-rate MFMA instructions
+  // whenever possible. Specifically:
+  // - For MFMA operations with non-K = 16: when blockK > 16, mfma16x16x32 is
+  //   selected; otherwise (blockK <= 16), mfma16x16x16 remains the choice.
+  // - For MFMA operations with non-K = 32: when blockK > 8, mfma32x32x16 is
+  //   selected; otherwise (blockK <= 8), mfma32x32x8 is used.
+  //
+  // In double-rate MFMA instructions, each thread holds 8 elements along the
+  // K dimension:
+  // - The first 4 elements belong to the first sub-tile.
+  // - The next 4 elements belong to the second sub-tile.
+  //
+  // We then group these into larger tiles, each consisting of 8 of these 16x4
+  // sub-tiles. Each such tile holds the data for one mfma instruction. The
+  // tile shape depends on the MFMA instruction used:
+  // 1. For mfma32x32x16, the tile shape is [non-K, K] = {32, 16}.
+  // 2. For mfma16x16x32, the tile shape is [non-K, K] = {16, 32}.
+  //
+  // For single-rate mfma instructions, each thread holds 4 elements along the
+  // K dimension. This means the larger tile (corresponding to one mfma
+  // instruction) consists of 4 16x4 sub-tiles.
+  std::vector<std::vector<int32_t>> registerBase = {{1, 0},
+                                                    {2, 0}}; // first sub-tile
+  std::vector<std::vector<int32_t>> laneBase = {{kWidthTransRead, 0},
+                                                {2 * kWidthTransRead, 0},
+                                                {0, 1},
+                                                {0, 2}}; // first sub-tile
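+  // With kWidthTransRead == 4, laneBase starts as {{4, 0}, {8, 0}, {0, 1},
+  // {0, 2}}: lane bits 0-1 step along non-K, lane bits 2-3 along K, so 16
+  // lanes plus the 4 register elements per lane cover one
+  // [non-K, K] = {16, 4} sub-tile.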
+
+  // Extend the register base for multiple tiles in the K dimension
+  // (corresponding to multiple mfma instructions across the K dim).
+  auto populateRegisterBase = [&](int kTileSize) {
+    const int regsPerTile = 8;
+    int numRegs = (kSize / kTileSize) * regsPerTile;
+    for (int reg = regsPerTile; reg < numRegs; reg *= 2) {
+      registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
+    }
+  };
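+  // e.g. for mfma16x16x32 (kTileSize == 32) and kSize == 64: numRegs == 16,
+  // so the loop appends {0, 32}, addressing the second mfma tile along K.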
+
+  const bool isMfma32 = (mfmaLayout.getMDim() == 32);
+  const bool isMfma16 = (mfmaLayout.getMDim() == 16);
+  const int kTileSize = isMfma32 ? 16 : 32;
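+  // kTileSize is the K extent of one double-rate mfma tile: 16 for
+  // mfma32x32x16 and 32 for mfma16x16x32.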
+
+  if (kSize >= kTileSize) {
+    // Handles mfma32x32x16 and mfma16x16x32 cases
+    assert(kWidthDot == 8);
+    registerBase.push_back({0, 4}); // second sub-tile
+    populateRegisterBase(kTileSize);
+    auto laneBaseExt = isMfma32
+                           ? std::vector<std::vector<int32_t>>{{16, 0}, {0, 8}}
+                           : std::vector<std::vector<int32_t>>{{0, 8}, {0, 16}};
+    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  } else {
+    // Handles mfma32x32x8 and mfma16x16x16 cases
+    assert(kWidthDot == 4);
+    auto laneBaseExt = isMfma32
+                           ? std::vector<std::vector<int32_t>>{{16, 0}, {0, 4}}
+                           : std::vector<std::vector<int32_t>>{{0, 4}, {0, 8}};
+    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  }
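+  // e.g. for mfma16x16x32 the final laneBase is {{4, 0}, {8, 0}, {0, 1},
+  // {0, 2}, {0, 8}, {0, 16}}: the 64 lanes together with the 8 registers of
+  // one tile cover the [non-K, K] = {16, 32} tile.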
+
+  // The base vectors above are defined in a fixed [non-k-dim, k-dim] order.
+  // The `order` array assigns them to the actual matrix dimensions:
+  // For operand A: non-k-dim -> dim0, k-dim -> dim1
+  // For operand B: non-k-dim -> dim1, k-dim -> dim0
+  LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
+                          {outDimNames[order[0]], outDimNames[order[1]]});
+
+  if (hasBatchDim) {
+    assert(order[2] == 0);
+    // Extend the base vector with one value to accommodate the batch
+    // dimension, which appears last.
+    tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
+    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
+  }
+
+  // warp order
+  // common for both operand A and B: [0, 1] / [0, 1, 2]
+  // in both cases it is [M dim, N dim]/[batch, M dim, N dim]
+  SmallVector<unsigned> warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout);
+  LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
+
+  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
+                           warpLayout.transposeOuts(outDimNames);
+  auto finalLayout =
+      combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
+
+  return finalLayout;
+}
+
 LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
                                    ArrayRef<int64_t> shape) {
 
@@ -1204,4 +1333,10 @@ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
   return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth);
 }
 
+LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
+                                       int32_t elemBitWidth) {
+  auto dot = cast<DotOperandEncodingAttr>(enc);
+  return chooseDotDsReadB64Tr16Layout(dot, shape, elemBitWidth);
+}
+
 } // namespace mlir::triton::gpu