@@ -470,93 +470,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
470470 return combineCtaCgaWithShape (tileLayout, getCTALayout (), shape);
471471}
472472
473- std::optional<LinearLayout>
474- chooseLLDsReadTrLayout (Attribute enc, ArrayRef<int64_t > shape,
475- int32_t elemBitWidth, unsigned instBitWidth,
476- unsigned numLanesInShuffleGroup) {
477- using BaseTy = std::vector<std::vector<int32_t >>;
478- // This function will derive the layout for the ds_read_tr instruction
479- // based on the input layout (LL/DotLayout/...)
480- // The ds_read_tr instruction works on instBitWidth per lane and in groups of
481- // numLanesInShuffleGroup lanes.
482-
483- // This example uses ds_read_b64_tr (instBitWidth = 64) with
484- // numLanesInShuffleGroup = 16 and 64 lanes per warp, taking an M-contiguous
485- // 16-bit input tensor A as the operand. Each lane loads 4 consecutive
486- // elements (64 bits in total) along M, and 4 consecutive lanes cover the 16
487- // rows of M. The loaded elements are then exchanged within the MxK=16x4 "base
488- // unit".
489- // K0 K1 K2 K3
490- // +---+---+---+---+
491- // M0 | | | | | M0, K[0-3]: T0
492- // M1 | T | T | T | T | M1, K[0-3]: T1
493- // M2 | 0 | 4 | 8 |12 | M2, K[0-3]: T2
494- // M3 | | | | | M3, K[0-3]: T3
495- // +---+---+---+---+
496- // M4 | | | | | M4, K[0-3]: T4
497- // M5 | T | T | T | T | M5, K[0-3]: T5
498- // M6 | 1 | 5 | 9 |13 | M6, K[0-3]: T6
499- // M7 | | | | | M7, K[0-3]: T7
500- // +---+---+---+---+ ==>
501- // M8 | | | | | M8, K[0-3]: T8
502- // M9 | T | T | T | T | M9, K[0-3]: T9
503- // M10 | 2 | 6 |10 |14 | M10, K[0-3]: T10
504- // M11 | | | | | M11, K[0-3]: T11
505- // +---+---+---+---+
506- // M12 | | | | | M12, K[0-3]: T12
507- // M13 | T | T | T | T | M13, K[0-3]: T13
508- // M14 | 3 | 7 |11 |15 | M14, K[0-3]: T14
509- // M15 | | | | | M15, K[0-3]: T15
510- // +---+---+---+---+
511-
512- // Given the layout represented by `enc` and shape, we can derive the layout
513- // that ds_read_b64_tr needs to have in order to perform a vectorized load of
514- // the elements. This can be done by rearranging the inner 4x16 element base
515- // unit in the LL by rearranging the first numReg register bases and the
516- // first numLane lane bases.
517- auto rotatePrefixes = [](BaseTy &regBase, std::size_t numReg,
518- BaseTy &laneBase, std::size_t numLane) {
519- // Concatenate prefixes of the two vectors. Lane first and then regs.
520- // C D E F | A B
521- // Then copy over numReg to the regBase and numLane to laneBase
522- // C D | E F A B
523- BaseTy baseUnit (laneBase.begin (), laneBase.begin () + numLane);
524- llvm::append_range (
525- baseUnit, llvm::make_range (regBase.begin (), regBase.begin () + numReg));
526-
527- std::copy (baseUnit.begin (), baseUnit.begin () + numReg, regBase.begin ());
528- std::copy (baseUnit.begin () + numReg, baseUnit.end (), laneBase.begin ());
529- };
530-
531- auto ctx = enc.getContext ();
532- assert (elemBitWidth == 8 || elemBitWidth == 16 );
533- // Get how many register bases and lane bases the ds_read_tr tile spans
534- unsigned numRegBases = llvm::Log2_32 (instBitWidth / elemBitWidth);
535- unsigned numLaneBases = llvm::Log2_32 (numLanesInShuffleGroup);
536-
537- auto ldsTransLayout = triton::gpu::toLinearLayout (shape, enc);
538- auto bases = ldsTransLayout.getBases ();
539- auto kRegister = S ("register" );
540- auto kLane = S ("lane" );
541-
542- // Make sure that we have enough register bases to rotate, otherwise we
543- // can't return a valid ds_read_tr layout
544- if (ldsTransLayout.getInDimSizeLog2 (kRegister ) < numRegBases) {
545- return std::nullopt ;
546- }
547- // We should always have enough lanes
548- assert (ldsTransLayout.getInDimSizeLog2 (kLane ) >= numLaneBases);
549- rotatePrefixes (bases[kRegister ], numRegBases, bases[kLane ], numLaneBases);
550- // Scale types double the elements for a total of 16 vgpr (still only 16
551- // elements contiguous). Need to adjust the lane basis to reflect that
552- if (elemBitWidth == 8 && numLanesInShuffleGroup == 8 ) {
553- assert (ldsTransLayout.getInDimSizeLog2 (kLane ) >= (numLaneBases + 1 ));
554- std::swap (bases[kLane ][numLaneBases - 1 ], bases[kLane ][numLaneBases]);
555- }
556-
557- return LinearLayout (bases, ldsTransLayout.getOutDims (), false );
558- }
559-
560473std::optional<LinearLayout>
561474chooseDotDsReadTrLayout (DotOperandEncodingAttr dotMfmaLayout,
562475 ArrayRef<int64_t > shape, int32_t elemBitWidth,
@@ -1192,20 +1105,39 @@ LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
11921105 LinearLayout::identity1D (encoding.getCTASplitN (), kCol , dims[1 ]);
11931106 auto newEncoding = TensorMemoryEncodingAttr::get (
11941107 ctx, encoding.getBlockM (), encoding.getBlockN (),
1195- encoding.getColStride (), encoding.getCTASplitM (), 1 );
1108+ encoding.getColStride (), encoding.getCTASplitM (), 1 ,
1109+ encoding.getTwoCTAs ());
11961110 return tensorMemoryToLinearLayout (
11971111 {shape[0 ], shape[1 ] / encoding.getCTASplitN ()}, newEncoding) *
11981112 split;
11991113 }
12001114 if (encoding.getCTASplitM () > 1 ) {
1201- auto split =
1202- LinearLayout::identity1D (encoding.getCTASplitM (), kCol , dims[0 ]);
1115+ auto splitM = encoding.getCTASplitM ();
1116+ auto blockM = encoding.getBlockM ();
1117+ bool isM64TwoCTA = blockM == 64 && encoding.getTwoCTAs ();
1118+ if (isM64TwoCTA) {
1119+ // blockM == 64 and twoCTAs is laid out as the transpose of 128xblockN
1120+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-b
1121+ blockM *= 2 ;
1122+ splitM /= 2 ;
1123+ }
1124+ auto split = LinearLayout::identity1D (splitM, kCol , dims[0 ]);
12031125 auto newEncoding = TensorMemoryEncodingAttr::get (
1204- ctx, encoding.getBlockM (), encoding.getBlockN (),
1205- encoding.getColStride (), 1 , encoding.getCTASplitN ());
1206- return tensorMemoryToLinearLayout (
1207- {shape[0 ] / encoding.getCTASplitM (), shape[1 ]}, newEncoding) *
1208- split;
1126+ ctx, blockM, encoding.getBlockN (), encoding.getColStride (), 1 ,
1127+ encoding.getCTASplitN (), encoding.getTwoCTAs ());
1128+ auto ret =
1129+ tensorMemoryToLinearLayout ({shape[0 ] / splitM, shape[1 ]}, newEncoding) *
1130+ split;
1131+ // In this case, we swap the basis of the last row and last column as per
1132+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-bny
1133+ if (isM64TwoCTA) {
1134+ auto bases = ret.getBases ();
1135+ auto &rowBases = bases[kRow ];
1136+ auto &colBases = bases[kCol ];
1137+ std::swap (rowBases[rowBases.size () - 1 ], colBases[colBases.size () - 1 ]);
1138+ ret = LinearLayout (bases, ret.getOutDims (), ret.isSurjective ());
1139+ }
1140+ return ret;
12091141 }
12101142 assert (encoding.getCTASplitM () == 1 && encoding.getCTASplitN () == 1 );
12111143
@@ -1461,14 +1393,10 @@ std::optional<LinearLayout>
14611393chooseDsReadTrLayout (Attribute enc, ArrayRef<int64_t > shape,
14621394 int32_t elemBitWidth, unsigned instBitWidth,
14631395 unsigned numLanesInShuffleGroup) {
1464- if (elemBitWidth == 4 ) {
1465- auto dot = cast<DotOperandEncodingAttr>(enc);
1466- return chooseDotDsReadTrLayout (dot, shape, elemBitWidth, instBitWidth,
1467- numLanesInShuffleGroup);
1468- } else {
1469- return chooseLLDsReadTrLayout (enc, shape, elemBitWidth, instBitWidth,
1470- numLanesInShuffleGroup);
1471- }
1396+ assert (elemBitWidth == 4 );
1397+ auto dot = cast<DotOperandEncodingAttr>(enc);
1398+ return chooseDotDsReadTrLayout (dot, shape, elemBitWidth, instBitWidth,
1399+ numLanesInShuffleGroup);
14721400}
14731401
14741402LinearLayout chooseScaledWmmaScaleLayout (MLIRContext *ctx, int dotOperandIdx,
0 commit comments