
Commit 8868aca

Merge commit '210c7b5bb29c01781c3e3053fe6bf28eb178347f'
2 parents 3b3a787 + 210c7b5 commit 8868aca

38 files changed (+2278, -863 lines)


include/triton/Tools/LayoutUtils.h

Lines changed: 35 additions & 0 deletions
@@ -147,6 +147,41 @@ std::pair<int, ColumnAction>
 largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth,
                      std::optional<int> maybeMaxVecElems = std::nullopt);

+// Close cousin of doing zerosLike(tile) * divideLeft(cvt, tile)
+// This one is a tad more general in the sense that it allows dividing
+// cvt:
+//  - register=1  -> (0, 1)
+//    register=2  -> (8, 0)
+//    register=4  -> (0, 8)
+//    register=8  -> (0, 16)
+//    register=16 -> (0, 32)
+//    register=32 -> (0, 64)
+//    register=64 -> (16, 0)
+//  - lane=1  -> (0, 2)
+//    lane=2  -> (0, 4)
+//    lane=4  -> (1, 0)
+//    lane=8  -> (2, 0)
+//    lane=16 -> (4, 0)
+//  - warp=1 -> (32, 0)
+//    warp=2 -> (64, 0)
+//  - block is a size 1 dimension
+// where out dims are: [row (size 128), col (size 128)]
+// by tile:
+//  - register=1 -> (0, 1)
+//    register=2 -> (8, 0)
+//  - lane=1  -> (0, 2)
+//    lane=2  -> (0, 4)
+//    lane=4  -> (1, 0)
+//    lane=8  -> (2, 0)
+//    lane=16 -> (4, 0)
+//  - warp=1 -> (32, 0)
+//    warp=2 -> (64, 0)
+// where out dims are: [row (size 128), col (size 8)]
+// which would not be possible to lower via the divideLeft approach, as we
+// cannot divide by the tile given the `register=64 -> (16, 0)` basis.
+std::optional<LinearLayout> getReps(const LinearLayout &cvt,
+                                    const LinearLayout &tile);
+
 } // namespace mlir::triton

 #endif // TRITON_TOOLS_LAYOUTUTILS_H
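For intuition about the layouts spelled out in the new getReps comment: these layouts map an input index to output coordinates, and the usual composition rule is that each output coordinate is the XOR of the bases whose input bits are set. The snippet below is a minimal standalone sketch of that rule under that assumption; it does not use the real mlir::triton::LinearLayout API, and Basis, apply, and the hard-coded values are purely illustrative, taken from the tile example quoted above.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

struct Basis { int32_t row, col; };

// XOR together the bases selected by the set bits of `index`.
static Basis apply(const std::vector<Basis> &bases, uint32_t index) {
  Basis out{0, 0};
  for (std::size_t b = 0; b < bases.size(); ++b)
    if (index & (1u << b)) {
      out.row ^= bases[b].row;
      out.col ^= bases[b].col;
    }
  return out;
}

int main() {
  // The `tile` example above: out dims are [row (size 128), col (size 8)].
  std::vector<Basis> regBases = {{0, 1}, {8, 0}};
  std::vector<Basis> laneBases = {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}};
  // Element held in register 1 of lane 5: XOR the per-dimension results.
  Basis r = apply(regBases, 1), l = apply(laneBases, 5);
  std::printf("row=%d col=%d\n", r.row ^ l.row, r.col ^ l.col); // row=1 col=3
}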

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "ALLOW_LHS_TMEM_LAYOUT_CONVERSION",
     "TRITON_F32_DEFAULT",
     "TRITON_PREFER_TMEM_16x256_LAYOUT",
+    "TRITON_ENABLE_EXPERIMENTAL_CONSAN",
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
     "TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32",

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 112 additions & 148 deletions
@@ -477,23 +477,87 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(tileLayout, getCTALayout(), shape);
 }

+LinearLayout chooseLLDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
+                                       int32_t elemBitWidth) {
+  using BaseTy = std::vector<std::vector<int32_t>>;
+  // This function derives the layout for the ds_read_b64_tr instruction
+  // based on the input layout (LL/DotLayout/...).
+  // ds_read_b64_tr works on 64 bits per lane and in groups of 16 lanes.
+
+  // Using an M-contiguous 16-bit input tensor A as an example: each lane
+  // loads 4 consecutive elements (64 bits in total) along M, and there are 4
+  // consecutive lanes in total along M. The loaded elements are then
+  // exchanged within the MxK=16x4 "base unit".
+  //       K0  K1  K2  K3
+  //      +---+---+---+---+
+  //  M0  |   |   |   |   |        M0,  K[0-3]: T0
+  //  M1  | T | T | T | T |        M1,  K[0-3]: T1
+  //  M2  | 0 | 4 | 8 |12 |        M2,  K[0-3]: T2
+  //  M3  |   |   |   |   |        M3,  K[0-3]: T3
+  //      +---+---+---+---+
+  //  M4  |   |   |   |   |        M4,  K[0-3]: T4
+  //  M5  | T | T | T | T |        M5,  K[0-3]: T5
+  //  M6  | 1 | 5 | 9 |13 |        M6,  K[0-3]: T6
+  //  M7  |   |   |   |   |        M7,  K[0-3]: T7
+  //      +---+---+---+---+  ==>
+  //  M8  |   |   |   |   |        M8,  K[0-3]: T8
+  //  M9  | T | T | T | T |        M9,  K[0-3]: T9
+  //  M10 | 2 | 6 |10 |14 |        M10, K[0-3]: T10
+  //  M11 |   |   |   |   |        M11, K[0-3]: T11
+  //      +---+---+---+---+
+  //  M12 |   |   |   |   |        M12, K[0-3]: T12
+  //  M13 | T | T | T | T |        M13, K[0-3]: T13
+  //  M14 | 3 | 7 |11 |15 |        M14, K[0-3]: T14
+  //  M15 |   |   |   |   |        M15, K[0-3]: T15
+  //      +---+---+---+---+
+
+  // Given the layout represented by `enc` and `shape`, we can derive the
+  // layout that ds_read_b64_tr needs to have in order to perform a
+  // vectorized load of the elements. This can be done by rearranging the
+  // inner 4x16 element base unit in the LL, i.e. by rearranging the first
+  // numReg register bases and the first numLane lane bases.
+  auto rotatePrefixes = [](BaseTy &regBase, std::size_t numReg,
+                           BaseTy &laneBase, std::size_t numLane) {
+    // Concatenate prefixes of the two vectors, lanes first and then regs:
+    //   C D E F | A B
+    // Then copy numReg entries back into regBase and numLane into laneBase:
+    //   C D | E F A B
+    BaseTy baseUnit(laneBase.begin(), laneBase.begin() + numLane);
+    llvm::append_range(
+        baseUnit, llvm::make_range(regBase.begin(), regBase.begin() + numReg));
+
+    std::copy(baseUnit.begin(), baseUnit.begin() + numReg, regBase.begin());
+    std::copy(baseUnit.begin() + numReg, baseUnit.end(), laneBase.begin());
+  };
+
+  auto ctx = enc.getContext();
+  assert(elemBitWidth == 8 || elemBitWidth == 16);
+  // Get how many register bases the ds_read_tr tile spans.
+  unsigned numRegBases = llvm::Log2_32(64 / elemBitWidth);
+  // 4 lane bases describe 16 lanes.
+  unsigned numLaneBases = 4;
+
+  auto ldsTransLayout = triton::gpu::toLinearLayout(shape, enc);
+  auto bases = ldsTransLayout.getBases();
+  auto kRegister = S("register");
+  auto kLane = S("lane");
+  rotatePrefixes(bases[kRegister], numRegBases, bases[kLane], numLaneBases);
+
+  return LinearLayout(bases, ldsTransLayout.getOutDims(), false);
+}
+
 LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
                                         ArrayRef<int64_t> shape,
                                         int32_t elemBitWidth) {
   auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
   auto mDim = mfmaLayout.getInstrShape()[0];
   assert(mDim == 16 || mDim == 32);

-  bool isFP4 = false;
-  if (elemBitWidth == 4) {
-    // When doing ds_read_tr4 we actually write the LL as if it were on i8
-    // elements this is becasue LL needs to be described for the i8 tensor
-    // elements.
-    elemBitWidth = 8;
-    isFP4 = true;
-  }
-
-  assert(elemBitWidth == 16 || elemBitWidth == 8);
+  assert(elemBitWidth == 4);
+  // When doing ds_read_tr4 we actually write the LL as if it were on i8
+  // elements; this is because the LL needs to be described for the i8 tensor
+  // elements.
+  elemBitWidth = 8;

   auto rank = shape.size();
   bool hasBatchDim = rank == 3;
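To make the base rotation in the new chooseLLDsReadB64TrLayout easier to follow, here is a minimal self-contained sketch of the same prefix rotation, written without the LLVM helpers; the free-standing function and the concrete base values in main are placeholders for illustration, not the real layout bases.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

using BaseTy = std::vector<std::vector<int32_t>>;

// Treat the first numLane lane bases followed by the first numReg register
// bases as one block (C D E F | A B), then split the block back at numReg:
// regs get C D, lanes get E F A B. Remaining bases are untouched.
void rotatePrefixes(BaseTy &regBase, std::size_t numReg, BaseTy &laneBase,
                    std::size_t numLane) {
  BaseTy block(laneBase.begin(), laneBase.begin() + numLane);
  block.insert(block.end(), regBase.begin(), regBase.begin() + numReg);
  std::copy(block.begin(), block.begin() + numReg, regBase.begin());
  std::copy(block.begin() + numReg, block.end(), laneBase.begin());
}

int main() {
  // 16-bit elements: numReg = log2(64 / 16) = 2 register bases, 4 lane bases.
  BaseTy regs = {{1, 0}, {2, 0}};                  // A B (placeholder bases)
  BaseTy lanes = {{0, 1}, {0, 2}, {0, 4}, {0, 8}}; // C D E F (placeholder bases)
  rotatePrefixes(regs, 2, lanes, 4);
  // regs is now {C, D}  = {{0, 1}, {0, 2}};
  // lanes is now {E, F, A, B} = {{0, 4}, {0, 8}, {1, 0}, {2, 0}}.
}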
@@ -520,143 +584,39 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,

   std::vector<std::vector<int32_t>> registerBase;
   std::vector<std::vector<int32_t>> laneBase;
-  auto populateFP4LL = [&registerBase, &laneBase](int kSize, int mDim) {
-    const bool isMfma32 = (mDim == 32);
-    // ds_read_b64_tr4 operates on FP4 values swapping the packing of them. Look
-    // at i8 values for the ownership of register/lane since it's the data type
-    // of the tensor. Register dimension: what i8 in the tile are held by thread
-    // 0? Lane dimension: what i8 in the tile are held in register 0 of each
-    // thread?
-    registerBase.push_back({1, 0});
-    registerBase.push_back({2, 0});
-    registerBase.push_back({4, 0});
-    registerBase.push_back({0, 16});
-
-    // If more than one tile needs to be loaded, populate registerBase
-    // dimension for the other tiles
-    const int kTileSize = isMfma32 ? 64 : 128;
-    for (int reg = kTileSize; reg < kSize; reg *= 2) {
-      registerBase.push_back({0, reg});
-    }
-
-    // When mDim == 16 we have 16x128 mfma, otherwise it's 16x64
-    // The LL for the two is different
-    laneBase.push_back({0, 1});
-    laneBase.push_back({0, 2});
-    laneBase.push_back({0, 4});
-    laneBase.push_back({0, 8});
-    if (mDim == 16) {
-      laneBase.push_back({0, 32});
-      laneBase.push_back({0, 64});
-    } else {
-      assert(mDim == 32);
-      laneBase.push_back({8, 0});
-      laneBase.push_back({0, 32});
-    }
-  };
-  auto populateLL = [&registerBase, &laneBase](int elemBitWidth, int kSize,
-                                               int kWidthDot, int mDim) {
-    // Number of bits loaded by an LDS read. ds_read_tr primarily supports
-    // 64-bit loads for most element sizes (16b, 8b, 4b).
-    const int32_t ldsReadWidth = 64;
-    int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
-    const int elemByteWidth = elemBitWidth / 8;
-    const bool isMfma32 = (mDim == 32);
-
-    // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes)
-    // of data. The smallest unit for transposition is a
-    // [non-K, K] = {16, kWidthTransRead} sub-tile of elements,
-    // where each thread reads kWidthTransRead elements along the non-K
-    // dimension. Due to the transposition mechanism, each thread ends up with
-    // kWidthTransRead elements along the K dimension.
-    //
-    // The MFMA selection logic prioritizes double-rate MFMA instructions
-    // whenever possible:
-    //
-    // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k
-    //   is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice.
-    //
-    // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is
-    //   selected; otherwise (blockK ≤ k), mfma32x32xk is used.
-    //
-    // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA
-    // instructions are used.
-    //
-    // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead
-    // elements along the K dimension:
-    //   - The first kWidthTransRead elements belong to the first sub-tile.
-    //   - The next kWidthTransRead elements belong to the second sub-tile.
-    //
-    // These elements are then grouped into larger tiles, each consisting of
-    // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data
-    // for one MFMA instruction. The shape of these tiles depends on the MFMA
-    // instruction used.
-    //
-    // For single-rate MFMA instructions, each thread holds kWidthTransRead
-    // elements along the K dimension. This means that the larger tile
-    // (corresponding to one MFMA instruction) consists of 4 {16,
-    // kWidthTransRead} sub-tiles.
-
-    // Populate register base for first subtile
-    for (int i = 1; i < kWidthTransRead; i *= 2) {
-      registerBase.push_back({i, 0});
-    }
-
-    const int threadsPerSubtileNonK = 16 / kWidthTransRead;
-    const int threadsPerSubtileK = kWidthTransRead;
-
-    // Populate lane base for first subtile
-    for (int i = 1; i < threadsPerSubtileNonK; i *= 2) {
-      laneBase.push_back({i * kWidthTransRead, 0});
-    }
-    for (int i = 1; i < threadsPerSubtileK; i *= 2) {
-      laneBase.push_back({0, i});
-    }
-
-    // Function to extend register base for multiple tiles K dim.
-    auto extendRegisterBaseForKDim = [&](int kTileSize,
-                                         int numSubtilesPerTile) {
-      const int regsPerTile = kWidthTransRead * numSubtilesPerTile;
-      int totalRegs = (kSize / kTileSize) * regsPerTile;
-
-      for (int reg = regsPerTile; reg < totalRegs; reg *= 2) {
-        registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
-      }
-    };
-
-    // kDoubleTileSize is the k dimension of a tile when double rated
-    // mfma instructions are used.
-    const int kDoubleTileSize =
-        isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth;
-    // kTileSize is the actually k dimention of a tile, which is
-    // determined by kWidthDot.
-    const int kTileSize = kWidthDot * 64 / mDim;
-    // We use kDoubleTileSize as a reference to check whether the given
-    // kWidthDot leads to double or single sub-tiles in each tile.
-    const int numSubtilesPerTile = (kTileSize == kDoubleTileSize) ? 2 : 1;
-
-    // Extend register base for large K sizes.
-    if (numSubtilesPerTile == 2)
-      registerBase.push_back({0, threadsPerSubtileK}); // Second subtile
-
-    extendRegisterBaseForKDim(kTileSize, numSubtilesPerTile);
-
-    // Extend lane base based on MFMA size.
-    std::vector<std::vector<int32_t>> laneBaseExt;
-
-    if (isMfma32) {
-      laneBaseExt = {{16, 0}, {0, numSubtilesPerTile * threadsPerSubtileK}};
-    } else {
-      laneBaseExt = {{0, numSubtilesPerTile * threadsPerSubtileK},
-                     {0, 2 * numSubtilesPerTile * threadsPerSubtileK}};
-    }
-    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
-  };

-  if (isFP4)
-    populateFP4LL(kSize, mDim);
-  else
-    populateLL(elemBitWidth, kSize, kWidthDot, mDim);
+  const bool isMfma32 = (mDim == 32);
+  // ds_read_b64_tr4 operates on FP4 values swapping the packing of them. Look
+  // at i8 values for the ownership of register/lane since it's the data type
+  // of the tensor. Register dimension: what i8 in the tile are held by thread
+  // 0? Lane dimension: what i8 in the tile are held in register 0 of each
+  // thread?
+  registerBase.push_back({1, 0});
+  registerBase.push_back({2, 0});
+  registerBase.push_back({4, 0});
+  registerBase.push_back({0, 16});
+
+  // If more than one tile needs to be loaded, populate registerBase
+  // dimension for the other tiles
+  const int kTileSize = isMfma32 ? 64 : 128;
+  for (int reg = kTileSize; reg < kSize; reg *= 2) {
+    registerBase.push_back({0, reg});
+  }
+
+  // When mDim == 16 we have 16x128 mfma, otherwise it's 16x64
+  // The LL for the two is different
+  laneBase.push_back({0, 1});
+  laneBase.push_back({0, 2});
+  laneBase.push_back({0, 4});
+  laneBase.push_back({0, 8});
+  if (mDim == 16) {
+    laneBase.push_back({0, 32});
+    laneBase.push_back({0, 64});
+  } else {
+    assert(mDim == 32);
+    laneBase.push_back({8, 0});
+    laneBase.push_back({0, 32});
+  }

   // Base vectors above are defined in a fixed order [non-k-dim, k-dim].
   // To assign them to actual matrix dimensions we associate with register
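The register bases pushed in the rewritten FP4 path answer the "what i8 in the tile are held by thread 0?" question directly: every XOR combination of the register bases is an element of thread 0. Below is a small standalone enumeration of that set for one K-tile; it assumes the usual XOR composition of such bases and is an illustrative model, not the Triton LinearLayout API.

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Register bases from the code above, before the extra-tile loop,
  // in [non-K, K] i8 coordinates.
  std::vector<std::array<int32_t, 2>> registerBase = {
      {1, 0}, {2, 0}, {4, 0}, {0, 16}};
  // Register index r selects bases via its set bits; XOR them together.
  for (uint32_t r = 0; r < (1u << registerBase.size()); ++r) {
    int32_t nonK = 0, k = 0;
    for (std::size_t b = 0; b < registerBase.size(); ++b)
      if (r & (1u << b)) {
        nonK ^= registerBase[b][0];
        k ^= registerBase[b][1];
      }
    std::printf("reg %2u -> (nonK=%d, K=%d)\n", static_cast<unsigned>(r), nonK, k);
  }
  // Output: thread 0 holds i8 elements at nonK = 0..7 for K = 0 and for K = 16.
}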
@@ -1444,8 +1404,12 @@ LinearLayout chooseShemLayoutForRegToRegConversion(

 LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
                                      int32_t elemBitWidth) {
-  auto dot = cast<DotOperandEncodingAttr>(enc);
-  return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
+  if (elemBitWidth == 4) {
+    auto dot = cast<DotOperandEncodingAttr>(enc);
+    return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
+  } else {
+    return chooseLLDsReadB64TrLayout(enc, shape, elemBitWidth);
+  }
 }

 LinearLayout chooseScaledWmmaScaleLayout(
