@@ -470,93 +470,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
470470 return combineCtaCgaWithShape (tileLayout, getCTALayout (), shape);
471471}
472472
473- std::optional<LinearLayout>
474- chooseLLDsReadTrLayout (Attribute enc, ArrayRef<int64_t > shape,
475- int32_t elemBitWidth, unsigned instBitWidth,
476- unsigned numLanesInShuffleGroup) {
477- using BaseTy = std::vector<std::vector<int32_t >>;
478- // This function will derive the layout for the ds_read_tr instruction
479- // based on the input layout (LL/DotLayout/...)
480- // The ds_read_tr instruction works on instBitWidth per lane and in groups of
481- // numLanesInShuffleGroup lanes.
482-
483- // In this example we look at ds_read_b64_tr (instBitWidth = 64) and
484- // numLanesInShuffleGroup = 16 with 64 lanes per warp. Using M-continuous
485- // 16-bit input tensor A as an example. Each lane will load 4 consecutive
486- // elements (64-bit in total) along M. There are 4 consecutive lanes in total
487- // along M. Then the loaded elements are exchanged within the MxK=16x4 "base
488- // unit".
489- // K0 K1 K2 K3
490- // +---+---+---+---+
491- // M0 | | | | | M0, K[0-3]: T0
492- // M1 | T | T | T | T | M1, K[0-3]: T1
493- // M2 | 0 | 4 | 8 |12 | M2, K[0-3]: T2
494- // M3 | | | | | M3, K[0-3]: T3
495- // +---+---+---+---+
496- // M4 | | | | | M4, K[0-3]: T4
497- // M5 | T | T | T | T | M5, K[0-3]: T5
498- // M6 | 1 | 5 | 9 |13 | M6, K[0-3]: T6
499- // M7 | | | | | M7, K[0-3]: T7
500- // +---+---+---+---+ ==>
501- // M8 | | | | | M8, K[0-3]: T8
502- // M9 | T | T | T | T | M9, K[0-3]: T9
503- // M10 | 2 | 6 |10 |14 | M10, K[0-3]: T10
504- // M11 | | | | | M11, K[0-3]: T11
505- // +---+---+---+---+
506- // M12 | | | | | M12, K[0-3]: T12
507- // M13 | T | T | T | T | M13, K[0-3]: T13
508- // M14 | 3 | 7 |11 |15 | M14, K[0-3]: T14
509- // M15 | | | | | M15, K[0-3]: T15
510- // +---+---+---+---+
511-
512- // Given the layout represented by `enc` and shape, we can derive the layout
513- // that ds_read_b64_tr need to have in order to perform a vectorized load of
514- // the elements. This can be done by rearranging the inner 4x16 element base
515- // unit in the LL by rearranging the first numReg register bases and the
516- // first numLane lane bases.
517- auto rotatePrefixes = [](BaseTy ®Base, std::size_t numReg,
518- BaseTy &laneBase, std::size_t numLane) {
519- // Concatenate prefixes of the two vectors. Lane first and then regs.
520- // C D E F | A B
521- // Then copy over numReg to the regBase and numLane to laneBase
522- // C D | E F A B
523- BaseTy baseUnit (laneBase.begin (), laneBase.begin () + numLane);
524- llvm::append_range (
525- baseUnit, llvm::make_range (regBase.begin (), regBase.begin () + numReg));
526-
527- std::copy (baseUnit.begin (), baseUnit.begin () + numReg, regBase.begin ());
528- std::copy (baseUnit.begin () + numReg, baseUnit.end (), laneBase.begin ());
529- };
530-
531- auto ctx = enc.getContext ();
532- assert (elemBitWidth == 8 || elemBitWidth == 16 );
533- // Get how many reg bases and tile bases the ds_read_tr tile spans
534- unsigned numRegBases = llvm::Log2_32 (instBitWidth / elemBitWidth);
535- unsigned numLaneBases = llvm::Log2_32 (numLanesInShuffleGroup);
536-
537- auto ldsTransLayout = triton::gpu::toLinearLayout (shape, enc);
538- auto bases = ldsTransLayout.getBases ();
539- auto kRegister = S (" register" );
540- auto kLane = S (" lane" );
541-
542- // Make sure that we have enough register bases to rotate, otherwise we
543- // can't return a valid ds_read_tr layout
544- if (ldsTransLayout.getInDimSizeLog2 (kRegister ) < numRegBases) {
545- return std::nullopt ;
546- }
547- // We should always have enough lanes
548- assert (ldsTransLayout.getInDimSizeLog2 (kLane ) >= numLaneBases);
549- rotatePrefixes (bases[kRegister ], numRegBases, bases[kLane ], numLaneBases);
550- // Scale types double the elements for a total of 16 vgpr (still only 16
551- // elements contiguous). Need to adjust the lane basis to reflect that
552- if (elemBitWidth == 8 && numLanesInShuffleGroup == 8 ) {
553- assert (ldsTransLayout.getInDimSizeLog2 (kLane ) >= (numLaneBases + 1 ));
554- std::swap (bases[kLane ][numLaneBases - 1 ], bases[kLane ][numLaneBases]);
555- }
556-
557- return LinearLayout (bases, ldsTransLayout.getOutDims (), false );
558- }
559-
560473std::optional<LinearLayout>
561474chooseDotDsReadTrLayout (DotOperandEncodingAttr dotMfmaLayout,
562475 ArrayRef<int64_t > shape, int32_t elemBitWidth,
@@ -1461,14 +1374,10 @@ std::optional<LinearLayout>
14611374chooseDsReadTrLayout (Attribute enc, ArrayRef<int64_t > shape,
14621375 int32_t elemBitWidth, unsigned instBitWidth,
14631376 unsigned numLanesInShuffleGroup) {
1464- if (elemBitWidth == 4 ) {
1465- auto dot = cast<DotOperandEncodingAttr>(enc);
1466- return chooseDotDsReadTrLayout (dot, shape, elemBitWidth, instBitWidth,
1467- numLanesInShuffleGroup);
1468- } else {
1469- return chooseLLDsReadTrLayout (enc, shape, elemBitWidth, instBitWidth,
1470- numLanesInShuffleGroup);
1471- }
1377+ assert (elemBitWidth == 4 );
1378+ auto dot = cast<DotOperandEncodingAttr>(enc);
1379+ return chooseDotDsReadTrLayout (dot, shape, elemBitWidth, instBitWidth,
1380+ numLanesInShuffleGroup);
14721381}
14731382
14741383LinearLayout chooseScaledWmmaScaleLayout (MLIRContext *ctx, int dotOperandIdx,
0 commit comments