Skip to content

Commit de2ba39

Browse files
authored
[AMD] Refactor mfma layout (triton-lang#8213)
This PR made following changes to the exiting `AMDMfmaEncodingAttr`: - Replace `mDim` and `nDim` with `instrShape` in the form (M, N, K) to stay consistent with wmma: triton-lang#8174. - Change `elementType` to `elementBitWidth`. `elementBitWidth` only impacts the layout when it is 64. Previously, we are using a `mlir::Type` which is a little inconvenient to create and check in various places.
1 parent 799d846 commit de2ba39

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+557
-551
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,15 +1036,17 @@ An encoding for tensors that have been produced by MFMA matrix core instructions
10361036
available on AMD Instinct GPUs of CDNA architectures.
10371037

10381038
It is characterized by the following parameters:
1039-
- `version` indicates the GPU architecture:
1039+
- `version`: The GPU architecture:
10401040
- 1: gfx908: CDNA1
10411041
- 2: gfx90a: CDNA2
10421042
- 3: gfx942: CDNA3
10431043
- 4: gfx950: CDNA4
1044-
- `warpsPerCTA` indicates the warp layout in the block.
1045-
- `MDim` and `NDim` indicate the dimension of the output of the mfma instruction.
1046-
- `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout
1044+
- `warpsPerCTA`: The warp layout in the block.
1045+
- `instrShape`: The shape in the form of (M, N, K) of the matrix.
1046+
- `isTransposed`: Indicates the result tensor is transposed so that it can be converted to dotOperand layout
10471047
without going to shared memory. This is used in the case of chained dot (E.g. Flash-Attention kernel).
1048+
- `tilesPerWarp`: The tile layout within a warp. Defaults to unit tile layout, i.e., single tile on all dimensions.
1049+
- `elementBitWidth`: Bit width of the output element type. Supported values are 32 and 64. Defaults to 32.
10481050

10491051
Example 1:
10501052
Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32.
@@ -1154,25 +1156,27 @@ w2 w2 w3 w3
11541156
ins
11551157
"unsigned": $version,
11561158
ArrayRefParameter<"unsigned">:$warpsPerCTA,
1157-
ArrayRefParameter<"unsigned">:$tilesPerWarp,
1158-
"unsigned":$MDim,
1159-
"unsigned":$NDim,
1159+
ArrayRefParameter<"unsigned">:$instrShape,
11601160
"bool":$isTransposed,
11611161
"CTALayoutAttr":$CTALayout,
1162-
DefaultValuedParameter<"std::optional<Type>", "FloatType::get($_ctxt, 32)">:$elementType
1162+
ArrayRefParameter<"unsigned">:$tilesPerWarp,
1163+
"unsigned":$elementBitWidth
11631164
);
11641165

11651166
let builders = [
11661167
AttrBuilder<(ins "unsigned":$version,
11671168
"ArrayRef<unsigned>":$warpsPerCTA,
1168-
"unsigned":$MDim,
1169-
"unsigned":$NDim,
1169+
"ArrayRef<unsigned>":$instrShape,
11701170
"bool":$isTransposed,
11711171
"CTALayoutAttr":$CTALayout,
1172-
"std::optional<Type>":$elementType), [{
1173-
SmallVector<unsigned> tilesPerWarp(warpsPerCTA.size(), 1);
1174-
1175-
return $_get(context, version, warpsPerCTA, tilesPerWarp, MDim, NDim, isTransposed, CTALayout, elementType);
1172+
CArg<"ArrayRef<unsigned>", "{}">:$tpw,
1173+
CArg<"unsigned", "0">:$elementBitWidth), [{
1174+
SmallVector<unsigned> tilesPerWarp(tpw);
1175+
if (tilesPerWarp.empty())
1176+
tilesPerWarp = SmallVector<unsigned>(warpsPerCTA.size(), 1);
1177+
if (elementBitWidth == 0)
1178+
elementBitWidth = 32;
1179+
return $_get($_ctxt, version, warpsPerCTA, instrShape, isTransposed, CTALayout, tilesPerWarp, elementBitWidth);
11761180
}]>
11771181
];
11781182

@@ -1194,6 +1198,7 @@ w2 w2 w3 w3
11941198

11951199
let genVerifyDecl = 1;
11961200
let hasCustomAssemblyFormat = 1;
1201+
let skipDefaultBuilders = 1;
11971202
}
11981203

11991204
def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encoding", [MmaEncodingTrait]> {

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 41 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr,
574574
};
575575

576576
static LogicalResult parseType(AsmParser &parser, const NamedAttribute &attr,
577-
std::optional<Type> &value, StringRef desc) {
577+
Type &value, StringRef desc) {
578578
auto typeAttr = mlir::dyn_cast<TypeAttr>(attr.getValue());
579579
if (!typeAttr) {
580580
parser.emitError(parser.getNameLoc(), "expected a Type in ") << desc;
@@ -1168,33 +1168,27 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
11681168

11691169
unsigned version = 0;
11701170
SmallVector<unsigned> warpsPerCTA;
1171-
SmallVector<unsigned> tilesPerWarp;
11721171
SmallVector<unsigned> instrShape;
11731172
bool isTransposed;
11741173
std::optional<SmallVector<unsigned>> CTAsPerCGA;
11751174
std::optional<SmallVector<unsigned>> CTASplitNum;
11761175
std::optional<SmallVector<unsigned>> CTAOrder;
1177-
std::optional<Type> elementType;
1176+
SmallVector<unsigned> tilesPerWarp = {};
1177+
unsigned elementBitWidth = 32;
11781178

11791179
for (const NamedAttribute &attr : dict) {
11801180
if (attr.getName() == "version") {
1181-
if (parseUInt(parser, attr, version, "verison").failed())
1181+
if (parseUInt(parser, attr, version, "version").failed())
11821182
return {};
11831183
}
11841184
if (attr.getName() == "warpsPerCTA") {
11851185
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
11861186
return {};
11871187
}
1188-
if (attr.getName() == "tilesPerWarp") {
1189-
if (parseIntArrayAttr(parser, attr, tilesPerWarp, "tilesPerWarp")
1190-
.failed())
1191-
return {};
1192-
}
11931188
if (attr.getName() == "instrShape") {
11941189
if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed())
11951190
return {};
11961191
}
1197-
11981192
if (attr.getName() == "isTransposed") {
11991193
if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
12001194
return {};
@@ -1214,72 +1208,73 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
12141208
.failed())
12151209
return {};
12161210
}
1217-
if (attr.getName() == "elementType") {
1218-
if (parseType(parser, attr, elementType, "elementType").failed())
1211+
if (attr.getName() == "tilesPerWarp") {
1212+
if (parseIntArrayAttr(parser, attr, tilesPerWarp, "tilesPerWarp")
1213+
.failed())
1214+
return {};
1215+
}
1216+
if (attr.getName() == "elementBitWidth") {
1217+
if (parseUInt(parser, attr, elementBitWidth, "elementBitWidth").failed())
12191218
return {};
12201219
}
1221-
}
1222-
1223-
if (tilesPerWarp.empty()) {
1224-
tilesPerWarp.resize(warpsPerCTA.size(), 1);
12251220
}
12261221

12271222
std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
12281223
parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size());
12291224
if (!CTALayout.has_value())
12301225
return {};
12311226

1227+
if (tilesPerWarp.empty())
1228+
tilesPerWarp = SmallVector<unsigned>(instrShape.size(), 1);
1229+
12321230
return parser.getChecked<AMDMfmaEncodingAttr>(
1233-
parser.getContext(), version, warpsPerCTA, tilesPerWarp, instrShape[0],
1234-
instrShape[1], isTransposed, *CTALayout, elementType);
1231+
parser.getContext(), version, warpsPerCTA, instrShape, isTransposed,
1232+
*CTALayout, tilesPerWarp, elementBitWidth);
12351233
}
12361234

12371235
void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
12381236
printer << "<{"
1239-
<< "version = " << getVersion() //
1240-
<< ", warpsPerCTA = [" << getWarpsPerCTA() << "]";
1237+
<< "version = " << getVersion() //
1238+
<< ", warpsPerCTA = [" << getWarpsPerCTA() << "]" //
1239+
<< ", instrShape = [" << getInstrShape() << "]";
12411240

1242-
auto tilesPerWarp = getTilesPerWarp();
1243-
if (!hasUnitTilesPerWarp()) {
1244-
printer << ", tilesPerWarp = [" << getTilesPerWarp() << "]";
1245-
}
1241+
printer << ", isTransposed = " << getIsTransposed();
12461242

1247-
printer << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" //
1248-
<< ", isTransposed = " << getIsTransposed();
12491243
maybePrintCTALayout(getContext(), printer, getCTALayout(),
12501244
/*rank=*/getRank());
1251-
if (getElementType() && !(getElementType()->isF32())) {
1252-
std::string typeStr;
1253-
llvm::raw_string_ostream rso(typeStr);
1254-
getElementType()->print(rso);
1255-
printer << ", elementType = " << rso.str();
1256-
}
1245+
1246+
auto tilesPerWarp = getTilesPerWarp();
1247+
if (!hasUnitTilesPerWarp())
1248+
printer << ", tilesPerWarp = [" << getTilesPerWarp() << "]";
1249+
1250+
auto elementBitWidth = getElementBitWidth();
1251+
if (elementBitWidth != 32)
1252+
printer << ", elementBitWidth = " << elementBitWidth;
1253+
12571254
printer << "}>";
12581255
}
12591256

12601257
LogicalResult AMDMfmaEncodingAttr::verify(
12611258
function_ref<mlir::InFlightDiagnostic()> emitError, unsigned version,
12621259
llvm::ArrayRef<unsigned int> warpsPerCTA,
1263-
llvm::ArrayRef<unsigned int> tilesPerWarp, unsigned mDim, unsigned nDim,
1264-
bool isTransposed, mlir::triton::gpu::CTALayoutAttr,
1265-
std::optional<Type> elementType) {
1260+
llvm::ArrayRef<unsigned int> instrShape, bool isTransposed,
1261+
mlir::triton::gpu::CTALayoutAttr, llvm::ArrayRef<unsigned int> tilesPerWarp,
1262+
unsigned elementBitWidth) {
12661263
if (!(version >= 0 && version <= 4)) {
12671264
return emitError() << "version must be in the [0, 4] range";
12681265
}
12691266

1267+
auto mDim = instrShape[0];
1268+
auto nDim = instrShape[1];
12701269
const std::array<std::pair<unsigned, unsigned>, 4> validDims = {
12711270
{{32, 32}, {16, 16}, {64, 4}, {4, 64}}};
12721271
if (!llvm::is_contained(validDims, std::make_pair(mDim, nDim))) {
12731272
return emitError() << "invalid (mDim, nDim) combination: (" << mDim << ", "
12741273
<< nDim << ")";
12751274
}
1276-
if (elementType && !(elementType->isF64() || elementType->isF32() ||
1277-
elementType->isInteger(32))) {
1278-
std::string typeStr;
1279-
llvm::raw_string_ostream rso(typeStr);
1280-
elementType->print(rso);
1281-
return emitError() << "element type must be f64, f32, i32, or none";
1282-
}
1275+
1276+
if (!(elementBitWidth == 32 || elementBitWidth == 64))
1277+
return emitError() << "elementBitWidth must be 32 or 64";
12831278

12841279
return success();
12851280
}
@@ -2181,8 +2176,9 @@ bool AMDMfmaEncodingAttr::hasUnitTilesPerWarp() const {
21812176

21822177
SmallVector<int64_t>
21832178
AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
2184-
unsigned mDim = getMDim();
2185-
unsigned nDim = getNDim();
2179+
auto mnkDim = getInstrShape();
2180+
unsigned mDim = mnkDim[0];
2181+
unsigned nDim = mnkDim[1];
21862182
assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) ||
21872183
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
21882184

@@ -2279,7 +2275,7 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand(
22792275
std::max(std::min(simdWidth / perPhase, innerDimLength / vectorSize), 1u);
22802276

22812277
// TODO (zhanglx): figure out better parameters for mfma4
2282-
if (getMDim() == 4)
2278+
if (getInstrShape()[0] == 4)
22832279
maxPhase = 4;
22842280

22852281
return SwizzledSharedEncodingAttr::get(getContext(), vectorSize, perPhase,

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -373,10 +373,10 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
373373
auto dimM = outDimNames[order[1]];
374374
auto dimN = outDimNames[order[0]];
375375

376-
unsigned mDim = getMDim();
377-
unsigned nDim = getNDim();
378-
auto elementType = getElementType();
379-
int height = (elementType && elementType->isF64()) ? 1 : 4;
376+
auto mDim = getInstrShape()[0];
377+
auto nDim = getInstrShape()[1];
378+
auto elementBitWidth = getElementBitWidth();
379+
int height = elementBitWidth == 64 ? 1 : 4;
380380
constexpr int warpSize = 64;
381381

382382
bool isTransposed = getIsTransposed();
@@ -453,8 +453,7 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
453453
// the first argument is 0), an empty layout is created, so this identity
454454
// layout will not introduce any new registers.
455455
tileLayout *= LinearLayout::identity1D(
456-
shape[nIndex] / (getNDim() * warpsPerCTAN * tilesPerWarpN), kRegister,
457-
dimN);
456+
shape[nIndex] / (nDim * warpsPerCTAN * tilesPerWarpN), kRegister, dimN);
458457
tileLayout *= LinearLayout::identity1D(tilesPerWarpM, kRegister, dimM);
459458

460459
// Finally, extend the layout across warps in the M dimension.
@@ -481,7 +480,7 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
481480
ArrayRef<int64_t> shape,
482481
int32_t elemBitWidth) {
483482
auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
484-
auto mDim = mfmaLayout.getMDim();
483+
auto mDim = mfmaLayout.getInstrShape()[0];
485484
assert(mDim == 16 || mDim == 32);
486485

487486
bool isFP4 = false;
@@ -697,8 +696,8 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
697696
auto tilesPerWarp = mfmaLayout.getTilesPerWarp();
698697
auto tilePerWarpNonK = tilesPerWarp[nonKDimIndex];
699698

700-
auto mDim = mfmaLayout.getMDim();
701-
auto nDim = mfmaLayout.getNDim();
699+
auto mDim = mfmaLayout.getInstrShape()[0];
700+
auto nDim = mfmaLayout.getInstrShape()[1];
702701
auto opIdx = dotMfmaLayout.getOpIdx();
703702
auto nonKDim = opIdx == 0 ? mDim : nDim;
704703
constexpr int warpSize = 64;
@@ -1619,8 +1618,9 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) {
16191618

16201619
// We currently only support transposed [B]F16 MFMA32x32 and MFMA16x16 on
16211620
// CDNA4.
1622-
bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
1623-
bool isMfma16 = mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16;
1621+
auto mnkDim = mfmaLayout.getInstrShape();
1622+
bool isMfma32 = mnkDim[0] == 32 && mnkDim[1] == 32;
1623+
bool isMfma16 = mnkDim[0] == 16 && mnkDim[1] == 16;
16241624

16251625
auto valShape = valType.getShape();
16261626
// For mfma16x16, to use in-wavefront swap, we need to make sure the tiles

python/src/gluon_ir.cc

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ struct GluonLayouts {
102102
py::handle AMDMFMALayout;
103103
py::handle AMDWMMALayout;
104104
py::handle PaddedSharedLayout;
105-
py::handle GluonDType;
106105

107106
GluonLayouts() {
108107
auto layouts =
@@ -128,7 +127,6 @@ struct GluonLayouts {
128127
py::object(layouts.attr("PaddedSharedLayout")).release();
129128

130129
auto core = py::module::import("triton.language.core");
131-
GluonDType = py::object(core.attr("dtype")).release();
132130
}
133131
};
134132

@@ -218,26 +216,10 @@ py::object layoutToGluon(Attribute layout) {
218216
return layouts.AutoLayout();
219217
} else if (auto amdMfma = dyn_cast<ttg::AMDMfmaEncodingAttr>(layout)) {
220218
auto ctaLayout = amdMfma.getCTALayout();
221-
std::vector<unsigned> instrShape{amdMfma.getMDim(), amdMfma.getNDim()};
222-
auto elemTypeOpt = amdMfma.getElementType();
223-
const char *typeName = "fp32";
224-
if (elemTypeOpt.has_value()) {
225-
auto elemType = elemTypeOpt.value();
226-
if (elemType.isF64()) {
227-
typeName = "fp64";
228-
} else if (elemType.isF32()) {
229-
typeName = "fp32";
230-
} else {
231-
// The AMDMfmaEncodingAttr mlir attribute has already verified element
232-
// type is fp64, fp32 or int32; so, the typeName here must be int32.
233-
typeName = "int32";
234-
}
235-
}
236-
237219
return layouts.AMDMFMALayout(
238-
amdMfma.getVersion(), instrShape, amdMfma.getIsTransposed(),
239-
toStdVector(amdMfma.getWarpsPerCTA()), layouts.GluonDType(typeName),
240-
toStdVector(amdMfma.getTilesPerWarp()),
220+
amdMfma.getVersion(), toStdVector(amdMfma.getInstrShape()),
221+
amdMfma.getIsTransposed(), toStdVector(amdMfma.getWarpsPerCTA()),
222+
amdMfma.getElementBitWidth(), toStdVector(amdMfma.getTilesPerWarp()),
241223
toStdVector(ctaLayout.getCTAsPerCGA()),
242224
toStdVector(ctaLayout.getCTASplitNum()),
243225
toStdVector(ctaLayout.getCTAOrder()));
@@ -376,18 +358,19 @@ void init_gluon_ir(py::module &&m) {
376358
})
377359
.def("get_amd_mfma_layout",
378360
[](GluonOpBuilder &self, unsigned version,
361+
std::vector<unsigned> &warpsPerCta,
379362
std::vector<unsigned> &instrShape, bool transposed,
380-
std::vector<unsigned> &warpsPerCta, mlir::Type elemType,
381-
std::vector<unsigned> &tilesPerWarp,
382363
std::vector<unsigned> &ctasPerCga,
383364
std::vector<unsigned> &ctaSplitNum,
384-
std::vector<unsigned> &ctaOrder) -> Attribute {
365+
std::vector<unsigned> &ctaOrder,
366+
std::vector<unsigned> &tilesPerWarp,
367+
unsigned elementBitWidth) -> Attribute {
385368
auto ctx = self.getContext();
386369
auto ctaLayout = self.getChecked<ttg::CTALayoutAttr>(
387370
ctx, ctasPerCga, ctaSplitNum, ctaOrder);
388371
return ttg::AMDMfmaEncodingAttr::get(
389-
ctx, version, warpsPerCta, tilesPerWarp, instrShape[0],
390-
instrShape[1], transposed, ctaLayout, elemType);
372+
ctx, version, warpsPerCta, instrShape, transposed, ctaLayout,
373+
tilesPerWarp, elementBitWidth);
391374
})
392375
.def("get_amd_wmma_layout",
393376
[](GluonOpBuilder &self, unsigned version, bool transposed,

0 commit comments

Comments
 (0)