Skip to content

Commit 742d71b

Browse files
plognjen, oplavsic, antiagainst
authored
[AMD] Add tilesPerWarp parameter to mfma layout (#7283)
This PR introduces the tilesPerWarp parameter to the MFMA layout. Previously, the MFMA layout assumed that each warp within a CTA tile computed a single MFMA tile. When the tensor was larger than a single CTA tile, these tiles were repeated across the tensor. In this setup, the output tiles computed by each wave were strided by the number of warps per CTA in both row and column dimensions.

For instance, with 16 MFMA tiles and warpsPerCTA = [2, 2], the distribution of warps across the MFMA tiles looked like:

w0 w1 w0 w1
w2 w3 w2 w3
w0 w1 w0 w1
w2 w3 w2 w3

The new tilesPerWarp parameter allows each warp to compute contiguous MFMA tiles in the row and/or column dimensions. Using the same example with tilesPerWarp = [2, 2], the layout becomes:

w0 w0 w1 w1
w0 w0 w1 w1
w2 w2 w3 w3
w2 w2 w3 w3

While this is a general enhancement, the main motivation for introducing this parameter is to improve memory access efficiency for scale tensors in scaled dot operations. Specific patterns and use cases will be implemented in follow-up PRs.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
1 parent c2c7b9b commit 742d71b

File tree

10 files changed

+581
-101
lines changed

10 files changed

+581
-101
lines changed

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,11 @@ LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
282282
int numWarps);
283283

284284
// Create LinearLayout for scale in scaled mfma.
285-
LinearLayout chooseScaledMfmaScaleLayout(
286-
MLIRContext *ctx, int dotOperandIdx,
287-
const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
288-
ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim);
285+
LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
286+
ArrayRef<int64_t> dotOperandShape,
287+
unsigned mfmaMDim,
288+
ArrayRef<unsigned> tilesPerWarp,
289+
ArrayRef<unsigned> warpsPerCTA);
289290

290291
// Create LinearLayout for nvidia mma tile.
291292
LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,23 +1077,61 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
10771077
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
10781078
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
10791079
[ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
1080+
1081+
Example 4:
1082+
This example demonstrates the semantics of the tilesPerWarp parameter. The MFMA layout (with tilesPerWarp=[1,1])
1083+
assumes that each warp within a CTA tile computes a single MFMA tile. When the tensor is larger than
1084+
a single CTA tile, these tiles are repeated across the tensor. In this setup, the output tiles computed
1085+
by each wave are strided by the number of warps per CTA tile in both row and column dimensions.
1086+
1087+
For instance, with 16 MFMA tiles and warpsPerCTA = [2, 2], the distribution of warps across the MFMA
1088+
tiles looks like:
1089+
1090+
w0 w1 w0 w1
1091+
w2 w3 w2 w3
1092+
w0 w1 w0 w1
1093+
w2 w3 w2 w3
1094+
1095+
The tilesPerWarp parameter allows each warp to compute contiguous MFMA tiles in the row and/or column dimensions.
1096+
Using the same example with tilesPerWarp = [2, 2], the layout becomes:
1097+
1098+
w0 w0 w1 w1
1099+
w0 w0 w1 w1
1100+
w2 w2 w3 w3
1101+
w2 w2 w3 w3
10801102
}];
10811103

10821104
let parameters = (
10831105
ins
10841106
"unsigned": $version,
10851107
ArrayRefParameter<"unsigned">:$warpsPerCTA,
1108+
ArrayRefParameter<"unsigned">:$tilesPerWarp,
10861109
"unsigned":$MDim,
10871110
"unsigned":$NDim,
10881111
"bool":$isTransposed,
10891112
"CTALayoutAttr":$CTALayout
10901113
);
10911114

1115+
// Convenience builder that takes no tilesPerWarp argument: it creates a
// tilesPerWarp vector with one entry per element of warpsPerCTA, each set to 1,
// and forwards all arguments to $_get.
let builders = [
1116+
AttrBuilder<(ins "unsigned":$version,
1117+
"ArrayRef<unsigned>":$warpsPerCTA,
1118+
"unsigned":$MDim,
1119+
"unsigned":$NDim,
1120+
"bool":$isTransposed,
1121+
"CTALayoutAttr":$CTALayout), [{
1122+
SmallVector<unsigned> tilesPerWarp(warpsPerCTA.size(), 1);
1123+
return $_get(context, version, warpsPerCTA, tilesPerWarp, MDim, NDim, isTransposed, CTALayout);
1124+
}]>
1125+
];
1126+
10921127
let extraClassDeclaration = extraDistributedDeclaration # [{
10931128
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
10941129
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
10951130
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
10961131

1132+
// Check if tilesPerWarp is 1 in every dimension.
1133+
bool hasUnitTilesPerWarp() const;
1134+
10971135
// Returns a swizzled shared layout matching this MFMA layout for the
10981136
// dot operand at the given |operandIdx| with |operandShape|.
10991137
SwizzledSharedEncodingAttr composeSharedLayoutForOperand(

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ using namespace mlir;
2121
using namespace mlir::triton::gpu;
2222

2323
constexpr int kPtrBitWidth = 64;
24-
2524
struct ConvertLayoutOpUsingLinearLayoutsConversion
2625
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
2726
const TargetInfoBase &targetInfo;

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,7 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
13161316

13171317
unsigned version = 0;
13181318
SmallVector<unsigned> warpsPerCTA;
1319+
SmallVector<unsigned> tilesPerWarp;
13191320
SmallVector<unsigned> instrShape;
13201321
bool isTransposed;
13211322
std::optional<SmallVector<unsigned>> CTAsPerCGA;
@@ -1331,6 +1332,11 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
13311332
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
13321333
return {};
13331334
}
1335+
if (attr.getName() == "tilesPerWarp") {
1336+
if (parseIntArrayAttr(parser, attr, tilesPerWarp, "tilesPerWarp")
1337+
.failed())
1338+
return {};
1339+
}
13341340
if (attr.getName() == "instrShape") {
13351341
if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed())
13361342
return {};
@@ -1357,21 +1363,31 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
13571363
}
13581364
}
13591365

1366+
if (tilesPerWarp.empty()) {
1367+
tilesPerWarp.resize(warpsPerCTA.size(), 1);
1368+
}
1369+
13601370
std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
13611371
parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size());
13621372
if (!CTALayout.has_value())
13631373
return {};
13641374

13651375
return parser.getChecked<AMDMfmaEncodingAttr>(
1366-
parser.getContext(), version, warpsPerCTA, instrShape[0], instrShape[1],
1367-
isTransposed, *CTALayout);
1376+
parser.getContext(), version, warpsPerCTA, tilesPerWarp, instrShape[0],
1377+
instrShape[1], isTransposed, *CTALayout);
13681378
}
13691379

13701380
void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
13711381
printer << "<{"
1372-
<< "version = " << getVersion() //
1373-
<< ", warpsPerCTA = [" << getWarpsPerCTA() << "]" //
1374-
<< ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" //
1382+
<< "version = " << getVersion() //
1383+
<< ", warpsPerCTA = [" << getWarpsPerCTA() << "]";
1384+
1385+
auto tilesPerWarp = getTilesPerWarp();
1386+
if (!hasUnitTilesPerWarp()) {
1387+
printer << ", tilesPerWarp = [" << getTilesPerWarp() << "]";
1388+
}
1389+
1390+
printer << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" //
13751391
<< ", isTransposed = " << getIsTransposed();
13761392
maybePrintCTALayout(getContext(), printer, getCTALayout(),
13771393
/*rank=*/getRank());
@@ -1380,7 +1396,8 @@ void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
13801396

13811397
LogicalResult AMDMfmaEncodingAttr::verify(
13821398
function_ref<mlir::InFlightDiagnostic()> emitError, unsigned version,
1383-
llvm::ArrayRef<unsigned int> warpsPerCTA, unsigned mDim, unsigned nDim,
1399+
llvm::ArrayRef<unsigned int> warpsPerCTA,
1400+
llvm::ArrayRef<unsigned int> tilesPerWarp, unsigned mDim, unsigned nDim,
13841401
bool isTransposed, mlir::triton::gpu::CTALayoutAttr) {
13851402
if (!(version >= 0 && version <= 4)) {
13861403
return emitError() << "version must be in the [0, 4] range";
@@ -1873,6 +1890,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getCTASplitNum() const {
18731890
return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
18741891
}
18751892

1893+
// Returns true when no entry of getTilesPerWarp() differs from 1, i.e. every
// dimension has exactly one tile per warp. An empty array also yields true,
// because any_of over an empty range is false.
bool AMDMfmaEncodingAttr::hasUnitTilesPerWarp() const {
1894+
return !llvm::any_of(getTilesPerWarp(), [](int x) { return x != 1; });
1895+
}
1896+
18761897
SmallVector<int64_t>
18771898
AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
18781899
unsigned mDim = getMDim();
@@ -1908,21 +1929,27 @@ AMDMfmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
19081929
auto operandTileShape = getInstrShapeForOperand(kWidth, opIdx);
19091930
auto rank = operandShape.size();
19101931
auto warpsPerCTA = getWarpsPerCTA();
1932+
auto tilesPerWarp = getTilesPerWarp();
1933+
19111934
int numRepBatch =
19121935
rank == 3 ? std::max<int64_t>(1, operandShape[0] / warpsPerCTA[0]) : 1;
19131936
if (opIdx == 0)
19141937
return {
19151938
numRepBatch,
19161939
std::max<int64_t>(1, operandShape[rank - 2] /
1917-
(operandTileShape[0] * warpsPerCTA[rank - 2])),
1940+
(operandTileShape[0] * tilesPerWarp[rank - 2] *
1941+
warpsPerCTA[rank - 2])) *
1942+
tilesPerWarp[rank - 2],
19181943
std::max<int64_t>(1, operandShape[rank - 1] / operandTileShape[1])};
19191944
else {
19201945
assert(opIdx == 1);
19211946
return {
19221947
numRepBatch,
19231948
std::max<int64_t>(1, operandShape[rank - 2] / operandTileShape[0]),
1924-
std::max<int64_t>(1, operandShape[rank - 1] / (operandTileShape[1] *
1925-
warpsPerCTA[rank - 1]))};
1949+
std::max<int64_t>(1, operandShape[rank - 1] /
1950+
(operandTileShape[1] * tilesPerWarp[rank - 1] *
1951+
warpsPerCTA[rank - 1])) *
1952+
tilesPerWarp[rank - 1]};
19261953
}
19271954
}
19281955

0 commit comments

Comments
 (0)