Skip to content

Commit 1514139

Browse files
authored
[AMD] Rework lowering of ds_read_tr for b8/b16 types (#8525)
This uses a lowering similar to ldmatrix: we model the instruction directly with a linear layout (LL) instead of modelling it as a transformation of the requested output LL. This codepath is used for B8/B16 types and also for FP4 types packed along K. The path for FP4 packed along M/N still uses the legacy approach, since it requires an input/output shape change that isn't as straightforward.
1 parent d4cce9f commit 1514139

File tree

5 files changed

+322
-727
lines changed

5 files changed

+322
-727
lines changed

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 4 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -469,93 +469,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
469469
return combineCtaCgaWithShape(tileLayout, getCTALayout(), shape);
470470
}
471471

472-
std::optional<LinearLayout>
473-
chooseLLDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
474-
int32_t elemBitWidth, unsigned instBitWidth,
475-
unsigned numLanesInShuffleGroup) {
476-
using BaseTy = std::vector<std::vector<int32_t>>;
477-
// This function will derive the layout for the ds_read_tr instruction
478-
// based on the input layout (LL/DotLayout/...)
479-
// The ds_read_tr instruction works on instBitWidth per lane and in groups of
480-
// numLanesInShuffleGroup lanes.
481-
482-
// In this example we look at ds_read_b64_tr (instBitWidth = 64) and
483-
// numLanesInShuffleGroup = 16 with 64 lanes per warp. Using M-continuous
484-
// 16-bit input tensor A as an example. Each lane will load 4 consecutive
485-
// elements (64-bit in total) along M. There are 4 consecutive lanes in total
486-
// along M. Then the loaded elements are exchanged within the MxK=16x4 "base
487-
// unit".
488-
// K0 K1 K2 K3
489-
// +---+---+---+---+
490-
// M0 | | | | | M0, K[0-3]: T0
491-
// M1 | T | T | T | T | M1, K[0-3]: T1
492-
// M2 | 0 | 4 | 8 |12 | M2, K[0-3]: T2
493-
// M3 | | | | | M3, K[0-3]: T3
494-
// +---+---+---+---+
495-
// M4 | | | | | M4, K[0-3]: T4
496-
// M5 | T | T | T | T | M5, K[0-3]: T5
497-
// M6 | 1 | 5 | 9 |13 | M6, K[0-3]: T6
498-
// M7 | | | | | M7, K[0-3]: T7
499-
// +---+---+---+---+ ==>
500-
// M8 | | | | | M8, K[0-3]: T8
501-
// M9 | T | T | T | T | M9, K[0-3]: T9
502-
// M10 | 2 | 6 |10 |14 | M10, K[0-3]: T10
503-
// M11 | | | | | M11, K[0-3]: T11
504-
// +---+---+---+---+
505-
// M12 | | | | | M12, K[0-3]: T12
506-
// M13 | T | T | T | T | M13, K[0-3]: T13
507-
// M14 | 3 | 7 |11 |15 | M14, K[0-3]: T14
508-
// M15 | | | | | M15, K[0-3]: T15
509-
// +---+---+---+---+
510-
511-
// Given the layout represented by `enc` and shape, we can derive the layout
512-
// that ds_read_b64_tr needs to have in order to perform a vectorized load of
513-
// the elements. This can be done by rearranging the inner 4x16 element base
514-
// unit in the LL, i.e. by rotating the first numReg register bases and the
515-
// first numLane lane bases.
516-
auto rotatePrefixes = [](BaseTy &regBase, std::size_t numReg,
517-
BaseTy &laneBase, std::size_t numLane) {
518-
// Concatenate prefixes of the two vectors. Lane first and then regs.
519-
// C D E F | A B
520-
// Then copy over numReg to the regBase and numLane to laneBase
521-
// C D | E F A B
522-
BaseTy baseUnit(laneBase.begin(), laneBase.begin() + numLane);
523-
llvm::append_range(
524-
baseUnit, llvm::make_range(regBase.begin(), regBase.begin() + numReg));
525-
526-
std::copy(baseUnit.begin(), baseUnit.begin() + numReg, regBase.begin());
527-
std::copy(baseUnit.begin() + numReg, baseUnit.end(), laneBase.begin());
528-
};
529-
530-
auto ctx = enc.getContext();
531-
assert(elemBitWidth == 8 || elemBitWidth == 16);
532-
// Get how many reg bases and lane bases the ds_read_tr tile spans
533-
unsigned numRegBases = llvm::Log2_32(instBitWidth / elemBitWidth);
534-
unsigned numLaneBases = llvm::Log2_32(numLanesInShuffleGroup);
535-
536-
auto ldsTransLayout = triton::gpu::toLinearLayout(shape, enc);
537-
auto bases = ldsTransLayout.getBases();
538-
auto kRegister = S("register");
539-
auto kLane = S("lane");
540-
541-
// Make sure that we have enough register bases to rotate, otherwise we
542-
// can't return a valid ds_read_tr layout
543-
if (ldsTransLayout.getInDimSizeLog2(kRegister) < numRegBases) {
544-
return std::nullopt;
545-
}
546-
// We should always have enough lanes
547-
assert(ldsTransLayout.getInDimSizeLog2(kLane) >= numLaneBases);
548-
rotatePrefixes(bases[kRegister], numRegBases, bases[kLane], numLaneBases);
549-
// Scale types double the elements for a total of 16 vgpr (still only 16
550-
// elements contiguous). Need to adjust the lane basis to reflect that
551-
if (elemBitWidth == 8 && numLanesInShuffleGroup == 8) {
552-
assert(ldsTransLayout.getInDimSizeLog2(kLane) >= (numLaneBases + 1));
553-
std::swap(bases[kLane][numLaneBases - 1], bases[kLane][numLaneBases]);
554-
}
555-
556-
return LinearLayout(bases, ldsTransLayout.getOutDims(), false);
557-
}
558-
559472
std::optional<LinearLayout>
560473
chooseDotDsReadTrLayout(DotOperandEncodingAttr dotMfmaLayout,
561474
ArrayRef<int64_t> shape, int32_t elemBitWidth,
@@ -1457,14 +1370,10 @@ std::optional<LinearLayout>
14571370
chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
14581371
int32_t elemBitWidth, unsigned instBitWidth,
14591372
unsigned numLanesInShuffleGroup) {
1460-
if (elemBitWidth == 4) {
1461-
auto dot = cast<DotOperandEncodingAttr>(enc);
1462-
return chooseDotDsReadTrLayout(dot, shape, elemBitWidth, instBitWidth,
1463-
numLanesInShuffleGroup);
1464-
} else {
1465-
return chooseLLDsReadTrLayout(enc, shape, elemBitWidth, instBitWidth,
1466-
numLanesInShuffleGroup);
1467-
}
1373+
assert(elemBitWidth == 4);
1374+
auto dot = cast<DotOperandEncodingAttr>(enc);
1375+
return chooseDotDsReadTrLayout(dot, shape, elemBitWidth, instBitWidth,
1376+
numLanesInShuffleGroup);
14681377
}
14691378

14701379
LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,

0 commit comments

Comments
 (0)