Commit 8a45291

[AMD] Support fp64 MFMA instructions (#7461)
This commit adds support for lowering fp64 dot to MFMA intrinsics in the AMD backend.
1 parent 6bdb64a commit 8a45291
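For illustration, here is the kind of kernel this unlocks — a minimal sketch of ours, not part of the commit (the kernel name, shapes, and row-major indexing are illustrative assumptions); an fp64 tl.dot on a CDNA GPU now lowers to mfma_f64 intrinsics:

import triton
import triton.language as tl

@triton.jit
def matmul_fp64_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                       BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # fp64 accumulator: with this commit it maps onto MFMA on AMD hardware.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        # Row-major A (M x K) and B (K x N); assumes shapes divide the blocks.
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        acc += tl.dot(a, b)  # f64 x f64 -> f64
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)

# Example launch for float64 torch tensors a, b, c (M, N divisible by 64):
#   matmul_fp64_kernel[(M // 64, N // 64)](a, b, c, M, N, K,
#                                          BLOCK_M=64, BLOCK_N=64, BLOCK_K=32)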

11 files changed: +134 -44 lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 6 additions & 3 deletions
@@ -1109,7 +1109,8 @@ w2 w2 w3 w3
     "unsigned":$MDim,
     "unsigned":$NDim,
     "bool":$isTransposed,
-    "CTALayoutAttr":$CTALayout
+    "CTALayoutAttr":$CTALayout,
+    DefaultValuedParameter<"std::optional<Type>", "FloatType::get($_ctxt, 32)">:$elementType
   );

   let builders = [
@@ -1118,9 +1119,11 @@ w2 w2 w3 w3
     "unsigned":$MDim,
     "unsigned":$NDim,
     "bool":$isTransposed,
-    "CTALayoutAttr":$CTALayout), [{
+    "CTALayoutAttr":$CTALayout,
+    "std::optional<Type>":$elementType), [{
       SmallVector<unsigned> tilesPerWarp(warpsPerCTA.size(), 1);
-      return $_get(context, version, warpsPerCTA, tilesPerWarp, MDim, NDim, isTransposed, CTALayout);
+      return $_get(context, version, warpsPerCTA, tilesPerWarp, MDim, NDim, isTransposed, CTALayout, elementType);
   }]>
 ];

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 32 additions & 2 deletions
@@ -561,6 +561,17 @@ static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr,
   return parseBoolAttrValue(parser, attr.getValue(), value, desc);
 };

+static LogicalResult parseType(AsmParser &parser, const NamedAttribute &attr,
+                               std::optional<Type> &value, StringRef desc) {
+  auto typeAttr = mlir::dyn_cast<TypeAttr>(attr.getValue());
+  if (!typeAttr) {
+    parser.emitError(parser.getNameLoc(), "expected a Type in ") << desc;
+    return failure();
+  }
+  value = typeAttr.getValue();
+  return success();
+}
+
 // Print the CTALayout if it's not equal to the default.
 static void maybePrintCTALayout(mlir::MLIRContext *context,
                                 mlir::AsmPrinter &printer, CTALayoutAttr layout,
@@ -1327,6 +1338,7 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
   std::optional<SmallVector<unsigned>> CTAsPerCGA;
   std::optional<SmallVector<unsigned>> CTASplitNum;
   std::optional<SmallVector<unsigned>> CTAOrder;
+  std::optional<Type> elementType;

   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "version") {
@@ -1366,6 +1378,10 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
             .failed())
         return {};
     }
+    if (attr.getName() == "elementType") {
+      if (parseType(parser, attr, elementType, "elementType").failed())
+        return {};
+    }
   }

   if (tilesPerWarp.empty()) {
@@ -1379,7 +1395,7 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {

   return parser.getChecked<AMDMfmaEncodingAttr>(
       parser.getContext(), version, warpsPerCTA, tilesPerWarp, instrShape[0],
-      instrShape[1], isTransposed, *CTALayout);
+      instrShape[1], isTransposed, *CTALayout, elementType);
 }

 void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
@@ -1396,21 +1412,35 @@ void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
           << ", isTransposed = " << getIsTransposed();
   maybePrintCTALayout(getContext(), printer, getCTALayout(),
                       /*rank=*/getRank());
+  if (getElementType() && !(getElementType()->isF32())) {
+    std::string typeStr;
+    llvm::raw_string_ostream rso(typeStr);
+    getElementType()->print(rso);
+    printer << ", elementType = " << rso.str();
+  }
   printer << "}>";
 }

 LogicalResult AMDMfmaEncodingAttr::verify(
     function_ref<mlir::InFlightDiagnostic()> emitError, unsigned version,
     llvm::ArrayRef<unsigned int> warpsPerCTA,
     llvm::ArrayRef<unsigned int> tilesPerWarp, unsigned mDim, unsigned nDim,
-    bool isTransposed, mlir::triton::gpu::CTALayoutAttr) {
+    bool isTransposed, mlir::triton::gpu::CTALayoutAttr,
+    std::optional<Type> elementType) {
   if (!(version >= 0 && version <= 4)) {
     return emitError() << "version must be in the [0, 4] range";
   }
   if (!((mDim == 32 && nDim == 32) || (mDim == 16 && nDim == 16))) {
     return emitError()
            << "(M, N) cases other than (32, 32) or (16, 16) unimplemented";
   }
+  if (elementType && !(elementType->isF64() || elementType->isF32() ||
+                       elementType->isInteger(32))) {
+    std::string typeStr;
+    llvm::raw_string_ostream rso(typeStr);
+    elementType->print(rso);
+    return emitError() << "element type must be f64, f32, i32, or none";
+  }

   return success();
 }
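With the parser and printer above, a non-default element type round-trips through the attribute's textual form. A hypothetical example (f32 is the default and is elided when printing), mirroring the syntax used in the .mlir tests below:

#ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true, elementType = f64}>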

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 48 additions & 17 deletions
@@ -438,25 +438,56 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {outDimNames[order[0]], outDimNames[order[1]]});
   } else {
     assert(getMDim() == 16);
-    // For mfma with 16x16 output, each of the 64 threads holds 4 elements.
-    //
-    // For the register (i.e., element) dimension, these 4 elements are along
-    // the matrix C's M dimension, with 4 consecutive elements spanning 4 rows.
-    //
-    // For the lane (i.e., thread) dimension, these threads are along the
-    // matrix C's N dimension, with 16 consecutive threads covering a whole
-    // row and the next 16 threads start after a gap spanning 4 rows.
-    tileLayout = LinearLayout(
-        {{kRegister, {{0, 1}, {0, 2}}},
-         {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
-        {outDimNames[order[0]], outDimNames[order[1]]});
-    // For mfma.transposed layout, the element ownership among threads are
-    // "transposed" within each warp.
-    if (getIsTransposed())
-      tileLayout = LinearLayout(
-          {{kRegister, {{1, 0}, {2, 0}}},
-           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
-          {outDimNames[order[0]], outDimNames[order[1]]});
+    auto elementType = getElementType();
+    if (!(elementType && elementType->isF64())) {
+      // For mfma with 16x16 output (<= 32 bits), each of the 64 threads
+      // holds 4 elements.
+      //
+      // For the register (i.e., element) dimension, these 4 elements are
+      // along the matrix C's M dimension, with 4 consecutive elements
+      // spanning 4 rows.
+      //
+      // For the lane (i.e., thread) dimension, these threads are along the
+      // matrix C's N dimension, with 16 consecutive threads covering a whole
+      // row and the next 16 threads starting after a gap spanning 4 rows.
+      tileLayout = LinearLayout(
+          {{kRegister, {{0, 1}, {0, 2}}},
+           {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
+      // For the mfma.transposed layout, element ownership among threads is
+      // "transposed" within each warp.
+      if (getIsTransposed())
+        tileLayout = LinearLayout(
+            {{kRegister, {{1, 0}, {2, 0}}},
+             {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
+            {outDimNames[order[0]], outDimNames[order[1]]});
+
+    } else {
+      // For 64-bit mfma with 16x16 output, each of the 64 threads holds 4
+      // elements across 8 VGPRs; each 64-bit element is split across a pair
+      // of VGPRs, with the first VGPR holding the first 32 bits and the
+      // second holding the last 32 bits.
+      //
+      // For the register (i.e., element) dimension, these 4 elements are
+      // along the matrix C's M dimension, with 4 consecutive elements
+      // spanning 4 rows.
+      //
+      // For the lane (i.e., thread) dimension, these threads are along the
+      // matrix C's N dimension, with each group of 16 consecutive threads
+      // covering a whole adjacent row. Unlike the <= 32-bit cases, there are
+      // no row gaps between the groups.
+      tileLayout = LinearLayout(
+          {{kRegister, {{0, 4}, {0, 8}}},
+           {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 1}, {0, 2}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
+      // For the mfma.transposed layout, element ownership among threads is
+      // "transposed" within each warp.
+      if (getIsTransposed())
+        tileLayout = LinearLayout(
+            {{kRegister, {{4, 0}, {8, 0}}},
+             {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}},
+            {outDimNames[order[0]], outDimNames[order[1]]});
+    }
   }

   // Instead of defining the layout on a CTA tile and using the
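The new f64 bases are easy to sanity-check by expanding them. The following small Python sketch (ours, not from the commit) applies the usual linear-layout rule — an output coordinate is the XOR of the bases selected by the set bits of the register and lane indices — to the non-transposed f64 16x16 case above:

# Bases copied from the non-transposed f64 16x16 tileLayout above; each basis
# is a coordinate pair along (outDimNames[order[0]], outDimNames[order[1]]).
REG_BASES = [(0, 4), (0, 8)]                   # kRegister: 4 values -> 2 bits
LANE_BASES = [(1, 0), (2, 0), (4, 0), (8, 0),  # kLane: 64 lanes -> 6 bits
              (0, 1), (0, 2)]

def coords(reg, lane):
    # XOR together the bases picked out by the set bits of reg and lane.
    x = y = 0
    for bit, (bx, by) in enumerate(REG_BASES):
        if (reg >> bit) & 1:
            x ^= bx
            y ^= by
    for bit, (bx, by) in enumerate(LANE_BASES):
        if (lane >> bit) & 1:
            x ^= bx
            y ^= by
    return x, y

print([coords(r, 0) for r in range(4)])    # [(0, 0), (0, 4), (0, 8), (0, 12)]
print([coords(0, l) for l in range(16)])   # [(0, 0), (1, 0), ..., (15, 0)]

Lane 0's four registers sit 4 apart in the second output dimension, while lane bits 4-5 contribute the offsets 1-3; together registers and lanes tile 0..15 with no holes, matching the "no row gaps between the groups" comment in the diff.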

python/test/unit/language/test_matmul.py

Lines changed: 0 additions & 2 deletions
@@ -112,8 +112,6 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
         pytest.skip("Float8 requires compute capability >= 9")
     if (dtype_src_str == "float64") != (dtype_dst_str == "float64"):
         pytest.skip("Skipping unsupported case")
-    if dtype_src_str == "float64" and not is_cuda():
-        pytest.skip("Float64 not supported on HIP yet")
     if "float32" in dtype_src_str and dtype_dst_str == "float16":
         pytest.skip("Skipping unsupported case")
     if "float32" == dtype_src_str and NUM_CTAS > 1:

test/TritonGPU/invalid-attributes.mlir

Lines changed: 5 additions & 0 deletions
@@ -74,6 +74,11 @@

 // -----

+// expected-error@+1 {{element type must be f64, f32, i32, or none}}
+#mfma = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 1, 1], instrShape = [16, 16], isTransposed = false, elementType = f16}>
+
+// -----
+
 // expected-error@+1 {{interval values must all be power of two}}
 #shared = #ttg.padded_shared<[3:+2]>

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 5 additions & 3 deletions
@@ -145,6 +145,8 @@ struct DotOpMFMAConversionHelper {
     Value zero;
     if (elemType.isInteger(32))
       zero = b.i32_val(0);
+    else if (elemType.isF64())
+      zero = b.f64_val(0.0);
     else
       zero = b.f32_val(0.0);
     auto cond = b.icmp_ult(laneId, b.i32_val(subBlockSize));
@@ -462,9 +464,9 @@ struct DotOpMFMAConversionHelper {
     }

     // Step 2: process rawElems based on element type
-    // Note that for f32 input and XF32 is not allowed, nothing needs to
-    // be done and rawElems is inserted into the ValueTable directly
-    if (type.isF32() && !allowXF32) {
+    // Note that for f32/f64 input when XF32 is not allowed, nothing needs
+    // to be done and rawElems is inserted into the ValueTable directly.
+    if ((type.isF32() || type.isF64()) && !allowXF32) {
       dotOpVals[{b, nonK, kBaseVec}] =
           tb.extract_element(type, rawElems, tb.i32_val(0));
     } else {
third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ createTmpLayout(triton::gpu::DistributedEncodingTrait layout,
   if (auto src = dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(layout))
     return triton::gpu::AMDMfmaEncodingAttr::get(
         ctx, src.getVersion(), warpsPerCTA, src.getMDim(), src.getNDim(),
-        src.getIsTransposed(), src.getCTALayout());
+        src.getIsTransposed(), src.getCTALayout(), src.getElementType());
   if (auto src = dyn_cast<triton::gpu::AMDWmmaEncodingAttr>(layout))
     return triton::gpu::AMDWmmaEncodingAttr::get(
         ctx, src.getVersion(), src.getIsTransposed(), warpsPerCTA,

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 22 additions & 11 deletions
@@ -155,8 +155,15 @@ chooseMfmaInstruction(Location loc, int mfmaVersion, RankedTensorType cType,
   } else {
     int minSize = std::min(M, N);
     if (minSize >= 32) {
-      mDim = 32;
-      nDim = 32;
+      // On CDNA2-4, if the element type is f64, we use the 16x16 intrinsic
+      // as there's no 32x32 intrinsic.
+      if (aElemType.isF64() || bElemType.isF64()) {
+        mDim = 16;
+        nDim = 16;
+      } else {
+        mDim = 32;
+        nDim = 32;
+      }
     }
     if (minSize >= 16 && minSize < 32) {
       mDim = 16;
@@ -450,19 +457,22 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     auto warpsPerTile =
         warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});

+    Type mfmaAccType;
+    if (oldRetType.getElementType().isIntOrIndex())
+      mfmaAccType = rewriter.getIntegerType(32);
+    else if (oldRetType.getElementType().isF64())
+      mfmaAccType = rewriter.getF64Type();
+    else
+      mfmaAccType = rewriter.getF32Type();
+
     // Use transposed mfma layout to enable larger vectorization for global
     // store instructions.
     auto aElemTy = mfmaInstr->aElementType;
     ttg::AMDMfmaEncodingAttr mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),
         /*version*/ mfmaVersion, warpsPerTile,
-        /*instrShape*/ mDim, nDim, /*isTransposed=*/true, CTALayout);
-
-    Type mfmaAccType;
-    if (oldRetType.getElementType().isIntOrIndex())
-      mfmaAccType = rewriter.getIntegerType(32);
-    else
-      mfmaAccType = rewriter.getF32Type();
+        /*instrShape*/ mDim, nDim, /*isTransposed=*/true, CTALayout,
+        mfmaAccType);

     // convert accumulator
     auto oldAcc = dotOp.getC();
@@ -657,7 +667,7 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     // for global store instructions.
     auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         ctx, /*version=*/mfmaVersion, mfmaWarpsPerCTA, /*instrShape=*/mDim,
-        nDim, /*isTransposed=*/true, ctaLayout);
+        nDim, /*isTransposed=*/true, ctaLayout, oldRetType.getElementType());

     auto newRetType = RankedTensorType::get(
         oldRetType.getShape(), oldRetType.getElementType(), mfmaEnc);
@@ -815,7 +825,8 @@ class ScaledBlockedToScaledMFMAF8F6F4 final
     // for global store instructions.
     auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         ctx, /*version=*/mfmaVersion, warpsPerTile,
-        /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout);
+        /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout,
+        oldRetType.getElementType());

     auto newRetType =
         RankedTensorType::get(oldShape, oldRetType.getElementType(), mfmaEnc);
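Condensed, the tile-shape choice in chooseMfmaInstruction above behaves like this Python paraphrase (the function name is ours, and the diff only shows the >= 16 cases):

def choose_mfma_tile(M, N, has_f64):
    min_size = min(M, N)
    if min_size >= 32:
        # No 32x32 f64 MFMA intrinsic on CDNA2-4, so fall back to 16x16.
        return (16, 16) if has_f64 else (32, 32)
    if min_size >= 16:
        return (16, 16)
    raise NotImplementedError("smaller shapes are handled elsewhere")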

third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp

Lines changed: 8 additions & 0 deletions
@@ -127,6 +127,7 @@ MfmaDatabase::MfmaDatabase(MLIRContext *context) {
   TRITON_MFMA_v2to4(m, n, aET, bET, symbol, k, kBase)

   Builder b(context);
+  auto f64T = b.getF64Type();
   auto f32T = b.getF32Type();
   auto tf32T = b.getTF32Type();
   auto f16T = b.getF16Type();
@@ -139,6 +140,13 @@ MfmaDatabase::MfmaDatabase(MLIRContext *context) {
   auto fp4T = b.getType<Float4E2M1FNType>();

   mfmaMap = {
+      // f64 inputs
+      // mfma_f64_16x16x4f64
+      TRITON_MFMA_v2to4(16, 16, f64T, f64T, mfma_f64_16x16x4f64, 4, 1),
+      // mfma_f64_4x4x4f64
+      TRITON_MFMA_v2to4(4, 4, f64T, f64T, mfma_f64_4x4x4f64, 16, 1),
+      TRITON_MFMA_v2to4(4, 16, f64T, f64T, mfma_f64_4x4x4f64, 4, 1),
+      TRITON_MFMA_v2to4(16, 4, f64T, f64T, mfma_f64_4x4x4f64, 4, 1),
      // f32 inputs
      // mfma_f32_32x32x2f32
      TRITON_MFMA_v1to4(32, 32, f32T, f32T, mfma_f32_32x32x2f32, 2, 1),
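Reading the new entries, the k column is the K extent a single MFMA issue consumes (our reading; it matches the intrinsic name, e.g. mfma_f64_16x16x4f64 has k = 4), so for example:

# Our arithmetic, not code from the commit: a 16x16 output tile with
# BLOCK_K = 32 needs 32 // 4 = 8 issues of mfma_f64_16x16x4f64.
BLOCK_K, INSTR_K = 32, 4
assert BLOCK_K % INSTR_K == 0
issues_per_tile = BLOCK_K // INSTR_K  # 8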

unittest/Dialect/TritonGPU/DialectTest.cpp

Lines changed: 2 additions & 2 deletions
@@ -473,14 +473,14 @@ class AMDMfmaLayoutTest : public AMDLayoutTest {
                         ArrayRef<unsigned> warpsPerCTA) {
     return triton::gpu::AMDMfmaEncodingAttr::get(
         &ctx, /*version=*/2, warpsPerCTA, mDim, nDim,
-        /*isTransposed=*/false, ctaLayout);
+        /*isTransposed=*/false, ctaLayout, std::nullopt);
   }

   triton::gpu::AMDMfmaEncodingAttr
   createTransposedMFMA(int mDim, int nDim, ArrayRef<unsigned> warpsPerCTA) {
     return triton::gpu::AMDMfmaEncodingAttr::get(
         &ctx, /*version=*/2, warpsPerCTA, mDim, nDim,
-        /*isTransposed=*/true, ctaLayout);
+        /*isTransposed=*/true, ctaLayout, std::nullopt);
   }
 };
