
Commit be510cc

[BACKEND] Implement generic code to allow for dot_scaled(mmav3) and warp choices (triton-lang#5103)
Even though this refactor allows sharing the code for choosing warps and the mma version with the regular `DotOp`, we don't activate either of those two paths yet, since we still have to fix a couple of things for them to work. Putting this preliminary PR up to avoid packing too many things into one PR.
1 parent: e00903a
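For context, the sketch below shows the general shape of a warp-splitting heuristic that, once it is expressed only in terms of the output tile, can be shared between `DotOp` and the scaled-dot path. The function name, signature, and exact splitting rule are illustrative assumptions for this note, not code from this commit.

#include <cstdint>
#include <vector>

// Hypothetical sketch, not code from this commit: a warp-split heuristic
// written purely as a function of the output tile shape, so the same logic
// could serve DotOp and dot_scaled alike. Assumes numWarps is a power of two.
static std::vector<unsigned> pickWarpsPerTile(int64_t M, int64_t N,
                                              unsigned numWarps) {
  std::vector<unsigned> warps = {1, 1};
  // Greedily split along the larger remaining dimension so the per-warp
  // tile stays roughly square.
  while (warps[0] * warps[1] < numWarps) {
    if (M / warps[0] >= N / warps[1])
      warps[0] *= 2;
    else
      warps[1] *= 2;
  }
  return warps;
}

Under this illustrative rule, a 128x128 output tile with 4 warps splits as 2x2, while a 256x64 tile splits as 4x1.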

5 files changed: +214 -135 lines

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 3 additions & 4 deletions
@@ -371,10 +371,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();

-    // TODO (Keren): Currently, we handle general mma/blocked/slice ->
-    // mma/blocked/slice conversions.
-    // The following tasks must be completed before we can remove the layoutIsOK
-    // check:
+    // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
+    // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
+    // completed before we can remove the layoutIsOK check:
     // 1. Support for AMD's MFMA and WMMA
     std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
       if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 4 additions & 1 deletion
@@ -140,8 +140,11 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
   // Do for all DotOperandEncodingAttr once we have LLs for all of them
   static bool isSupportedDotOpLayout(Attribute layout) {
     if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+      // Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
+      // - kWidth == 8
       if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
-        return mma.isAmpere() && dot.getKWidth() == 8;
+        bool legacyLoweringIsBuggy = dot.getKWidth() >= 8;
+        return legacyLoweringIsBuggy && mma.isAmpere();
       }
       if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
         return true;

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 27 additions & 17 deletions
@@ -52,14 +52,23 @@ LogicalResult UpcastMXFPOp::verify() {
         "all dimensions except the last must match between operands");
   }

-  auto dotEncoding =
-      dyn_cast_or_null<DotOperandEncodingAttr>(xTy.getEncoding());
+  auto layoutX = xTy.getEncoding();
+  auto layoutScale = scaleTy.getEncoding();
+  if (bool(layoutX) != bool(layoutScale)) {
+    return emitOpError(
+        "Expected either both or neither operands to have an encoding");
+  }
+  // Nothing to check if no encoding. This is used to infer the return type in
+  // AccelerateMatmul.cpp
+  if (!layoutX) {
+    return success();
+  }
+
+  auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
   if (!dotEncoding) {
     return emitOpError("Expected a DotOperandEncodingAttr for values");
   }
-
-  auto blockedScale =
-      dyn_cast_or_null<BlockedEncodingAttr>(scaleTy.getEncoding());
+  auto blockedScale = dyn_cast<BlockedEncodingAttr>(layoutScale);
   if (!blockedScale) {
     return emitOpError("Expected a BlockOperandEncoding for scales");
   }

@@ -86,22 +95,23 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
   auto xShape = xTy.getShape();

   auto encoding = xTy.getEncoding();
-  if (!encoding) {
-    return emitOptionalError(loc, "expected an encoding");
-  }
-  if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
-    return emitOptionalError(loc, "expected a dotOperand encoding");
-  }

   if (typeEncoded == ScaleDotElemType::E2M1) {
-    auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-    auto newVEncoding = DotOperandEncodingAttr::get(
-        ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
-        oldEncoding.getKWidth() * 2);
+    RankedTensorType retTy;
+
     auto newShape = SmallVector<int64_t>(xShape);
     newShape.back() *= 2;
-    inferredReturnTypes.push_back(
-        RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding));
+    if (!encoding) {
+      retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
+    } else {
+      auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
+      auto newVEncoding = DotOperandEncodingAttr::get(
+          ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
+          oldEncoding.getKWidth() * 2);
+      retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
+                                    newVEncoding);
+    }
+    inferredReturnTypes.push_back(retTy);
   } else {
     inferredReturnTypes.push_back(xTy);
   }
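To make the E2M1 shape rule above concrete, here is a small self-contained illustration; the shapes and the kWidth value are made up for the example and are not taken from this commit. E2M1 is a 4-bit float format, so a packed i8 tensor carries two values per element; upcasting to bf16 therefore doubles the last dimension, and the dot-operand kWidth doubles to match.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Packed e2m1 input: each i8 element holds two 4-bit values.
  std::vector<int64_t> packedShape = {128, 32};
  int64_t kWidth = 4; // example input kWidth, chosen for illustration only

  // After upcasting to bf16 the last dimension doubles ...
  std::vector<int64_t> upcastShape = packedShape;
  upcastShape.back() *= 2; // {128, 64}

  // ... and so does the kWidth of the dot-operand encoding.
  int64_t newKWidth = kWidth * 2; // 8

  std::cout << "upcast shape: " << upcastShape[0] << " x " << upcastShape[1]
            << ", new kWidth: " << newKWidth << "\n";
}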
