@@ -469,93 +469,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
469469 return combineCtaCgaWithShape (tileLayout, getCTALayout (), shape);
470470}
471471
472- std::optional<LinearLayout>
473- chooseLLDsReadTrLayout (Attribute enc, ArrayRef<int64_t > shape,
474- int32_t elemBitWidth, unsigned instBitWidth,
475- unsigned numLanesInShuffleGroup) {
476- using BaseTy = std::vector<std::vector<int32_t >>;
477- // This function will derive the layout for the ds_read_tr instruction
478- // based on the input layout (LL/DotLayout/...)
479- // The ds_read_tr instruction works on instBitWidth per lane and in groups of
480- // numLanesInShuffleGroup lanes.
481-
482- // In this example we look at ds_read_b64_tr (instBitWidth = 64) and
483- // numLanesInShuffleGroup = 16 with 64 lanes per warp. Using M-continuous
484- // 16-bit input tensor A as an example. Each lane will load 4 consecutive
485- // elements (64-bit in total) along M. There are 4 consecutive lanes in total
486- // along M. Then the loaded elements are exchanged within the MxK=16x4 "base
487- // unit".
488- // K0 K1 K2 K3
489- // +---+---+---+---+
490- // M0 | | | | | M0, K[0-3]: T0
491- // M1 | T | T | T | T | M1, K[0-3]: T1
492- // M2 | 0 | 4 | 8 |12 | M2, K[0-3]: T2
493- // M3 | | | | | M3, K[0-3]: T3
494- // +---+---+---+---+
495- // M4 | | | | | M4, K[0-3]: T4
496- // M5 | T | T | T | T | M5, K[0-3]: T5
497- // M6 | 1 | 5 | 9 |13 | M6, K[0-3]: T6
498- // M7 | | | | | M7, K[0-3]: T7
499- // +---+---+---+---+ ==>
500- // M8 | | | | | M8, K[0-3]: T8
501- // M9 | T | T | T | T | M9, K[0-3]: T9
502- // M10 | 2 | 6 |10 |14 | M10, K[0-3]: T10
503- // M11 | | | | | M11, K[0-3]: T11
504- // +---+---+---+---+
505- // M12 | | | | | M12, K[0-3]: T12
506- // M13 | T | T | T | T | M13, K[0-3]: T13
507- // M14 | 3 | 7 |11 |15 | M14, K[0-3]: T14
508- // M15 | | | | | M15, K[0-3]: T15
509- // +---+---+---+---+
510-
511- // Given the layout represented by `enc` and shape, we can derive the layout
512- // that ds_read_b64_tr needs to have in order to perform a vectorized load of
513- // the elements. This can be done by rearranging the inner 4x16 element base
514- // unit in the LL by rearranging the first numReg register bases and the
515- // first numLane lane bases.
516- auto rotatePrefixes = [](BaseTy &regBase, std::size_t numReg,
517- BaseTy &laneBase, std::size_t numLane) {
518- // Concatenate prefixes of the two vectors. Lane first and then regs.
519- // C D E F | A B
520- // Then copy over numReg to the regBase and numLane to laneBase
521- // C D | E F A B
522- BaseTy baseUnit (laneBase.begin (), laneBase.begin () + numLane);
523- llvm::append_range (
524- baseUnit, llvm::make_range (regBase.begin (), regBase.begin () + numReg));
525-
526- std::copy (baseUnit.begin (), baseUnit.begin () + numReg, regBase.begin ());
527- std::copy (baseUnit.begin () + numReg, baseUnit.end (), laneBase.begin ());
528- };
529-
530- auto ctx = enc.getContext ();
531- assert (elemBitWidth == 8 || elemBitWidth == 16 );
532- // Get how many reg bases and tile bases the ds_read_tr tile spans
533- unsigned numRegBases = llvm::Log2_32 (instBitWidth / elemBitWidth);
534- unsigned numLaneBases = llvm::Log2_32 (numLanesInShuffleGroup);
535-
536- auto ldsTransLayout = triton::gpu::toLinearLayout (shape, enc);
537- auto bases = ldsTransLayout.getBases ();
538- auto kRegister = S ("register" );
539- auto kLane = S ("lane" );
540-
541- // Make sure that we have enough register bases to rotate, otherwise we
542- // can't return a valid ds_read_tr layout
543- if (ldsTransLayout.getInDimSizeLog2 (kRegister ) < numRegBases) {
544- return std::nullopt ;
545- }
546- // We should always have enough lanes
547- assert (ldsTransLayout.getInDimSizeLog2 (kLane ) >= numLaneBases);
548- rotatePrefixes (bases[kRegister ], numRegBases, bases[kLane ], numLaneBases);
549- // Scale types double the elements for a total of 16 vgpr (still only 16
550- // elements contiguous). Need to adjust the lane basis to reflect that
551- if (elemBitWidth == 8 && numLanesInShuffleGroup == 8 ) {
552- assert (ldsTransLayout.getInDimSizeLog2 (kLane ) >= (numLaneBases + 1 ));
553- std::swap (bases[kLane ][numLaneBases - 1 ], bases[kLane ][numLaneBases]);
554- }
555-
556- return LinearLayout (bases, ldsTransLayout.getOutDims (), false );
557- }
558-
559472std::optional<LinearLayout>
560473chooseDotDsReadTrLayout (DotOperandEncodingAttr dotMfmaLayout,
561474 ArrayRef<int64_t > shape, int32_t elemBitWidth,
@@ -1457,14 +1370,10 @@ std::optional<LinearLayout>
14571370chooseDsReadTrLayout (Attribute enc, ArrayRef<int64_t > shape,
14581371 int32_t elemBitWidth, unsigned instBitWidth,
14591372 unsigned numLanesInShuffleGroup) {
1460- if (elemBitWidth == 4 ) {
1461- auto dot = cast<DotOperandEncodingAttr>(enc);
1462- return chooseDotDsReadTrLayout (dot, shape, elemBitWidth, instBitWidth,
1463- numLanesInShuffleGroup);
1464- } else {
1465- return chooseLLDsReadTrLayout (enc, shape, elemBitWidth, instBitWidth,
1466- numLanesInShuffleGroup);
1467- }
1373+ assert (elemBitWidth == 4 );
1374+ auto dot = cast<DotOperandEncodingAttr>(enc);
1375+ return chooseDotDsReadTrLayout (dot, shape, elemBitWidth, instBitWidth,
1376+ numLanesInShuffleGroup);
14681377}
14691378
14701379LinearLayout chooseScaledWmmaScaleLayout (MLIRContext *ctx, int dotOperandIdx,
0 commit comments