Commit bceed1f
[AMD] Enable larger kBase mfma instructions for smaller kWidth (triton-lang#6540)

Previously, we assumed that kBase <= kWidth. This PR adds support for the case where kBase > kWidth. The two cases are now unified: each thread simply packs kBase elements into a vector and converts it to the type required by the mfma instruction. For FA kernels, we need a smaller kWidth to avoid in-warp shuffling during the layout conversion from #mma to #dotOp for the 2nd dot. That was done in triton-lang#6532, which had to select mfma instructions with a smaller kBase. This PR enables the mfma instructions with the largest kBase to be selected even with kWidth = 4.
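To make the packing arithmetic concrete, here is a minimal standalone sketch (not the Triton code itself; the value of numRepK is an illustrative assumption, while kWidth = 4 and kBase = 8 match the double-rate f16 case in the tests below):

```cpp
#include <cassert>
#include <cstdio>

// kWidth: elements each thread holds per K repetition in the dot-operand
//         layout; kBase: elements one mfma instruction consumes per thread.
// Whether kBase <= kWidth or kBase > kWidth, the K loop reduces to one loop
// over numVecInKBase = numRepK * kWidth / kBase kBase-element vectors.
int main() {
  int numRepK = 4; // K repetitions per warp (assumed for illustration)
  int kWidth = 4;  // kept small to avoid in-warp shuffles for the 2nd dot
  int kBase = 8;   // mfma_16x16x32 f16 consumes 8 elements per thread
  assert((numRepK * kWidth) % kBase == 0);
  int numVecInKBase = numRepK * kWidth / kBase;
  std::printf("issue %d mfma ops, each fed a vector of %d elements\n",
              numVecInKBase, kBase);
}
```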
1 parent 1294776 commit bceed1f

File tree

2 files changed: +131 -143 lines changed
test/TritonGPU/amd/mfma-double-rate.mlir

Lines changed: 5 additions & 6 deletions
@@ -61,8 +61,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 
 // -----
 
-// When kWidth is set to 4, generate single rated mfma instructions.
-// In a future PR, such cases will still generate double rated mfma instructions with kWidth = 4.
+// When kWidth is set to 4, still generate double rated mfma instructions.
 
 // CHECK-LABEL:mfma_16x16x32_f16
 
@@ -74,7 +73,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
                   %q: tensor<128x128xf16, #dotOp0>,
                   %k: tensor<128x128xf16, #dotOp1>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
-    // CHECK: rocdl.mfma.f32.16x16x16f16 {{.*}} : (vector<4xf16>, vector<4xf16>
+    // CHECK: rocdl.mfma.f32.16x16x32.f16 {{.*}} : (vector<8xf16>, vector<8xf16>
     %qk = tt.dot %q, %k, %cst : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
     tt.return
   }
@@ -92,7 +91,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
                   %q: tensor<128x128xbf16, #dotOp0>,
                   %k: tensor<128x128xbf16, #dotOp1>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
-    // CHECK: rocdl.mfma.f32.16x16x16bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>
+    // CHECK: rocdl.mfma.f32.16x16x32.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>
     %qk = tt.dot %q, %k, %cst : tensor<128x128xbf16, #dotOp0> * tensor<128x128xbf16, #dotOp1> -> tensor<128x128xf32, #mma>
     tt.return
   }
@@ -110,7 +109,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
                   %q: tensor<128x128xf16, #dotOp0>,
                   %k: tensor<128x128xf16, #dotOp1>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
-    // CHECK: rocdl.mfma.f32.32x32x8f16 {{.*}} : (vector<4xf16>, vector<4xf16>
+    // CHECK: rocdl.mfma.f32.32x32x16.f16 {{.*}} : (vector<8xf16>, vector<8xf16>
     %qk = tt.dot %q, %k, %cst : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
     tt.return
   }
@@ -128,7 +127,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
                   %q: tensor<128x128xbf16, #dotOp0>,
                   %k: tensor<128x128xbf16, #dotOp1>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
-    // CHECK: rocdl.mfma.f32.32x32x8bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>
+    // CHECK: rocdl.mfma.f32.32x32x16.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>
     %qk = tt.dot %q, %k, %cst : tensor<128x128xbf16, #dotOp0> * tensor<128x128xbf16, #dotOp1> -> tensor<128x128xf32, #mma>
     tt.return
   }
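The switch from vector<4x...> to vector<8x...> in these CHECK lines follows from the operand tile sizes: an mfma A operand is mDim x kDim elements spread across a 64-lane wavefront, so each lane feeds kBase = mDim * kDim / 64 elements per instruction. A minimal sketch of that arithmetic (assuming the 64-lane CDNA wavefront; the formula is inferred from the removed selection code below, not quoted from the Triton sources):

```cpp
#include <cstdio>

// Per-lane operand size (kBase) of an mfma instruction, assuming the
// mDim x kDim operand tile is distributed over a 64-lane wavefront.
int kBasePerLane(int mDim, int kDim) { return mDim * kDim / 64; }

int main() {
  std::printf("mfma_16x16x16: kBase = %d\n", kBasePerLane(16, 16)); // 4, single rate
  std::printf("mfma_16x16x32: kBase = %d\n", kBasePerLane(16, 32)); // 8, double rate
  std::printf("mfma_32x32x8:  kBase = %d\n", kBasePerLane(32, 8));  // 4, single rate
  std::printf("mfma_32x32x16: kBase = %d\n", kBasePerLane(32, 16)); // 8, double rate
}
```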

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 126 additions & 137 deletions
@@ -25,6 +25,7 @@
 #include "TritonAMDGPUTransforms/MfmaGroup.h"
 #include "Utility.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -281,21 +282,8 @@ struct DotOpMFMAConversionHelper {
     auto aEncoding = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
     auto bEncoding = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
     int kWidth = aEncoding.getKWidth();
-    // If the kBase of the selected mfma instruction is larger than
-    // kWidth of the operand, it means the shape is large enough to
-    // use double rated mfma, but we (AccelerateAMDMatmul pass) choose
-    // to use single rated mfma.
-    if (kBase > kWidth) {
-      int kDimOperandSizeNew = 64 / mDim * kWidth;
-      maybeMfmaIntrinsic = MfmaIntrinsic::selectFor(
-          mfmaVersion, mDim, nDim, kDimOperandSizeNew, elemTyA, elemTyB,
-          /*withScale=*/false, allowXF32);
-      if (failed(maybeMfmaIntrinsic))
-        llvm::report_fatal_error("No match found in MFMA database\n");
-    }
 
     intrinsicName = maybeMfmaIntrinsic->name;
-    kBase = maybeMfmaIntrinsic->kBase;
 
     // If we are using XF32, the kWidth (and kBase) is double that of F32.
     if (aTensorTy.getElementType().isF32() && allowXF32)
@@ -335,6 +323,7 @@ struct DotOpMFMAConversionHelper {
     const int subBlocks =
         getNumSubmatrices(aTensorTy.getElementType(), mDim, nDim);
     auto elemsPerVec = mDim * nDim * subBlocks / warpSize;
+    int numVecInKBase = numRepK * kWidth / kBase;
 
     Value firstMfma;
     auto vecTy = vec_ty(dstElemTy, elemsPerVec);
@@ -350,18 +339,14 @@ struct DotOpMFMAConversionHelper {
                                   tb.i32_val(v));
         }
         acc = zeroAuxiliarBlocks(subBlocks, acc);
-        for (int k = 0; k < numRepK; k++) {
-          for (int kPack = 0; kPack < kWidth / kBase; ++kPack) {
-            acc = mfmaLayout.getIsTransposed()
-                      ? generateMFMAOp(intrinsicName,
-                                       operandB[kPack][{b, n, k}],
-                                       operandA[kPack][{b, m, k}], acc)
-                      : generateMFMAOp(intrinsicName,
-                                       operandA[kPack][{b, m, k}],
-                                       operandB[kPack][{b, n, k}], acc);
-            if (!firstMfma)
-              firstMfma = acc;
-          }
+        for (int k = 0; k < numVecInKBase; k++) {
+          acc = mfmaLayout.getIsTransposed()
+                    ? generateMFMAOp(intrinsicName, operandB[{b, n, k}],
+                                     operandA[{b, m, k}], acc)
+                    : generateMFMAOp(intrinsicName, operandA[{b, m, k}],
+                                     operandB[{b, n, k}], acc);
+          if (!firstMfma)
+            firstMfma = acc;
         }
         acc = reduceSubBlocks(subBlocks, acc);
         adjustAccForSmallKDim(fc, acc, dstElemTy, b, m, n, numRepM, numRepN,
@@ -387,109 +372,120 @@ struct DotOpMFMAConversionHelper {
     return success();
   }
 
-  /// Extract vector from rawElems based on kWidth and kBase
-  /// rawElems is a vector of kWidth elements. We need to prepare vector(s) of
-  /// kBase elements for each mfma instruction
-  SmallVector<Value> extractOperands(Value rawElems, int kWidth, int kBase,
-                                     Type type, bool preserveBF16,
-                                     bool isConstantScale = false) const {
+  /// Process the elements in rawElems and prepare a vector for mfma input.
+  /// rawElems is a vector of kBase elements. Each element is of the raw
+  /// element type from the input. We need to prepare a vector of kBase
+  /// elements of appropriate element type required by mfma instructions.
+  Value prepareOperands(Value rawElems, int kBase, Type type, bool preserveBF16,
+                        bool isConstantScale = false) const {
     auto b = TritonLLVMOpBuilder(loc, rewriter);
-    int kpack = kWidth / kBase;
-    SmallVector<Value> results;
+    Value results;
+
+    // Construct a vector type of kBase elements with desired type
     auto vecTy = vec_ty(type, kBase);
     if (type.isBF16() && !preserveBF16)
       vecTy = vec_ty(i16_ty, kBase);
-    for (int k = 0; k < kpack; ++k) {
-      Value vec = b.undef(vecTy);
-      for (int elemId = 0; elemId < kBase; ++elemId) {
-        auto val =
-            b.extract_element(type, rawElems, b.i32_val(elemId + k * kBase));
-        if (type.isBF16() && !preserveBF16) {
-          // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type
-          auto cast = b.bitcast(val, i16_ty);
-          vec = b.insert_element(vecTy, vec, cast, b.i32_val(elemId));
-        } else {
-          vec = b.insert_element(vecTy, vec, val, b.i32_val(elemId));
-        }
-      }
-      if (type.getIntOrFloatBitWidth() == 8) {
-        if (1 == kBase) {
-          // This is only for the scale operands of scaled mfma on CDNA4
-          if (isConstantScale) {
-            // If the scale is constant(created by arith::ConstantOp), it will
-            // be put in a sgpr instead of vgpr. In that case, instead of
-            // vgpr[7:0], the instruction reads sgpr[30:23] as the scale value.
-            // So we need to manually left shift the scale by 23 bits to meet
-            // the requirement.
-            results.push_back(b.shl(
-                i32_ty, b.zext(i32_ty, b.bitcast(vec, i8_ty)), b.i32_val(23)));
-          } else {
-            results.push_back(b.zext(i32_ty, b.bitcast(vec, i8_ty)));
-          }
-        }
-        if (4 == kBase)
-          // This is for int8 on pre- CDNA3 GPUs
-          results.push_back(b.bitcast(vec, i32_ty));
-        if (8 == kBase)
-          results.push_back(b.bitcast(vec, i64_ty));
-        if (16 == kBase)
-          // This is only for the operands of scaled mfma on CDNA4
-          results.push_back(b.bitcast(vec, vec_ty(i32_ty, 4)));
-        if (32 == kBase)
-          results.push_back(b.bitcast(vec, vec_ty(i32_ty, 8)));
-      } else {
-        results.push_back(vec);
+    Value vec = b.undef(vecTy);
+
+    // For each element in rawElems, extract the element as the desired type,
+    // bitcast it if needed, and insert it into vec.
+    for (int elemId = 0; elemId < kBase; ++elemId) {
+      auto val = b.extract_element(type, rawElems, b.i32_val(elemId));
+      if (type.isBF16() && !preserveBF16) {
+        // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type
+        auto cast = b.bitcast(val, i16_ty);
+        vec = b.insert_element(vecTy, vec, cast, b.i32_val(elemId));
+      } else {
+        vec = b.insert_element(vecTy, vec, val, b.i32_val(elemId));
       }
     }
+
+    // Now we have a vector of kBase elements of desired type.
+    // Then we need to prepare vec for results.
+    if (type.getIntOrFloatBitWidth() == 8) {
+      if (1 == kBase) {
+        // This is only for the scale operands of scaled mfma on CDNA4
+        if (isConstantScale) {
+          // If the scale is constant(created by arith::ConstantOp), it will
+          // be put in a sgpr instead of vgpr. In that case, instead of
+          // vgpr[7:0], the instruction reads sgpr[30:23] as the scale value.
+          // So we need to manually left shift the scale by 23 bits to meet
+          // the requirement.
+          results = b.shl(i32_ty, b.zext(i32_ty, b.bitcast(vec, i8_ty)),
+                          b.i32_val(23));
+        } else {
+          results = b.zext(i32_ty, b.bitcast(vec, i8_ty));
+        }
+      }
+      if (4 == kBase)
+        // This is for int8 on pre- CDNA3 GPUs
+        results = b.bitcast(vec, i32_ty);
+      if (8 == kBase)
+        results = b.bitcast(vec, i64_ty);
+      if (16 == kBase)
+        // This is only for the operands of scaled mfma on CDNA4
+        results = b.bitcast(vec, vec_ty(i32_ty, 4));
+      if (32 == kBase)
+        results = b.bitcast(vec, vec_ty(i32_ty, 8));
+    } else {
+      results = vec;
+    }
     return results;
   }
 
   /// Converts dot operand structure to value table and converts types
   /// appropriate for mfma instructions
-  virtual SmallVector<ValueTable> getValuesFromDotOperandLayoutStruct(
-      Value value, int batch, int n0, int n1, int kWidth, int kBase, Type type,
-      bool allowXF32, bool preserveBF16, bool isConstantScale = false) const {
+  virtual ValueTable getValuesFromDotOperandLayoutStruct(
+      Value value, int batch, int nonKRep, int kRepInKWidth, int kWidth,
+      int kBase, Type type, bool allowXF32, bool preserveBF16,
+      bool isConstantScale = false) const {
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
     auto elems = unpackLLElements(loc, value, rewriter);
-    int kpack = kWidth / kBase;
-    SmallVector<ValueTable> dotOpVals(kpack);
+    // number of kBase-element vectors
+    int numVecInKBase = kRepInKWidth * kWidth / kBase;
+    ValueTable dotOpVals;
+
+    SmallVector<int64_t> bounds = {batch, nonKRep, numVecInKBase, kBase};
+    SmallVector<int64_t> strides = computeStrides(bounds);
     for (int b = 0; b < batch; ++b) {
-      for (int i = 0; i < n0; i++) {
-        for (int j = 0; j < n1; j++) {
+      for (int nonK = 0; nonK < nonKRep; nonK++) {
+        for (int kBaseVec = 0; kBaseVec < numVecInKBase; kBaseVec++) {
+          // For each kBase-element vector
+
+          // Step 1: construct each kBase-element vector by
+          // - extracting kBase elements from elems and
+          // - putting them into a kBase-element vector, i.e. rawElems
           Type elemTy = typeConverter->convertType(type);
-          Type ty = vec_ty(elemTy, kWidth);
+          Type ty = vec_ty(elemTy, kBase);
           Value rawElems = tb.undef(ty);
-          for (int k = 0; k < kWidth; ++k) {
-            rawElems = tb.insert_element(
-                ty, rawElems,
-                elems[kWidth * n1 * n0 * b + kWidth * n1 * i + kWidth * j + k],
-                tb.i32_val(k));
+          for (int k = 0; k < kBase; ++k) {
+            auto index = linearize({b, nonK, kBaseVec, k}, strides);
+            rawElems =
+                tb.insert_element(ty, rawElems, elems[index], tb.i32_val(k));
          }
 
-          Value convertedElems;
+          // Step 2: process rawElems based on element type
+          // Note that for f32 input and XF32 is not allowed, nothing needs to
+          // be done and rawElems is inserted into the ValueTable directly
           if (type.isF32() && !allowXF32) {
-            for (int k = 0; k < kpack; ++k)
-              dotOpVals[k][{b, i, j}] =
-                  tb.extract_element(type, rawElems, tb.i32_val(k));
+            dotOpVals[{b, nonK, kBaseVec}] =
+                tb.extract_element(type, rawElems, tb.i32_val(0));
           } else {
-            SmallVector<Value> vals;
+            Value vals;
             if (type.isF32() && allowXF32) {
-              vals = extractOperands(rawElems, kWidth, kBase, f32_ty,
-                                     preserveBF16);
+              vals = prepareOperands(rawElems, kBase, f32_ty, preserveBF16);
             } else if (type.getIntOrFloatBitWidth() == 8) {
-              vals = extractOperands(rawElems, kWidth, kBase, i8_ty,
-                                     preserveBF16, isConstantScale);
+              vals = prepareOperands(rawElems, kBase, i8_ty, preserveBF16,
+                                     isConstantScale);
            } else if (type.isBF16()) {
-              vals = extractOperands(rawElems, kWidth, kBase, bf16_ty,
-                                     preserveBF16);
+              vals = prepareOperands(rawElems, kBase, bf16_ty, preserveBF16);
            } else {
              assert(type.isF16() && "Unsupported data type");
-              vals = extractOperands(rawElems, kWidth, kBase, f16_ty,
-                                     preserveBF16);
-            }
-            for (int k = 0; k < kpack; ++k) {
-              dotOpVals[k][{b, i, j}] = vals[k];
+              vals = prepareOperands(rawElems, kBase, f16_ty, preserveBF16);
            }
+
+            // Step 3: Insert the processed vals into the ValueTable
+            dotOpVals[{b, nonK, kBaseVec}] = vals;
          }
        }
      }
@@ -638,8 +634,8 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
 
     // Scales have the same replica distributions as their corresponding
     // operands.
-    SmallVector<ValueTable> operandAScale;
-    SmallVector<ValueTable> operandBScale;
+    ValueTable operandAScale;
+    ValueTable operandBScale;
     if (existBothScales) {
       auto aScaleTensorTy = cast<RankedTensorType>(aScale.getType());
       operandAScale = getValuesFromDotOperandLayoutStruct(
@@ -663,6 +659,7 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     const int subBlocks =
         getNumSubmatrices(aTensorTy.getElementType(), mDim, nDim);
     auto elemsPerVec = mDim * nDim * subBlocks / warpSize;
+    int numVecInKBase = numRepK * aKWidth / aKBase;
 
     Value firstMfma;
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
@@ -679,44 +676,36 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
                                   tb.i32_val(v));
         }
         acc = zeroAuxiliarBlocks(subBlocks, acc);
-        for (int k = 0; k < numRepK; k++) {
-          for (int kPack = 0; kPack < aKWidth / aKBase; ++kPack) {
-            if (existBothScales) {
-              if (mfmaLayout.getIsTransposed()) {
-                acc = generateScaledMFMAOp(intrinsicName,
-                                           operandB[kPack][{b, n, k}],
-                                           operandA[kPack][{b, m, k}], acc,
-                                           operandBScale[kPack][{b, n, k}],
-                                           operandAScale[kPack][{b, m, k}],
-                                           maybeMfmaIntrinsic->bElementType,
-                                           maybeMfmaIntrinsic->aElementType);
-              } else {
-                acc = generateScaledMFMAOp(intrinsicName,
-                                           operandA[kPack][{b, m, k}],
-                                           operandB[kPack][{b, n, k}], acc,
-                                           operandAScale[kPack][{b, m, k}],
-                                           operandBScale[kPack][{b, n, k}],
-                                           maybeMfmaIntrinsic->aElementType,
-                                           maybeMfmaIntrinsic->bElementType);
-              }
+        for (int k = 0; k < numVecInKBase; k++) {
+          if (existBothScales) {
+            if (mfmaLayout.getIsTransposed()) {
+              acc = generateScaledMFMAOp(
+                  intrinsicName, operandB[{b, n, k}], operandA[{b, m, k}],
+                  acc, operandBScale[{b, n, k}], operandAScale[{b, m, k}],
+                  maybeMfmaIntrinsic->bElementType,
+                  maybeMfmaIntrinsic->aElementType);
+            } else {
+              acc = generateScaledMFMAOp(
+                  intrinsicName, operandA[{b, m, k}], operandB[{b, n, k}],
+                  acc, operandAScale[{b, m, k}], operandBScale[{b, n, k}],
+                  maybeMfmaIntrinsic->aElementType,
+                  maybeMfmaIntrinsic->bElementType);
+            }
+          } else {
+            if (mfmaLayout.getIsTransposed()) {
+              acc = generateScaledMFMAOp(intrinsicName, operandB[{b, n, k}],
+                                         operandA[{b, m, k}], acc,
+                                         maybeMfmaIntrinsic->bElementType,
+                                         maybeMfmaIntrinsic->aElementType);
             } else {
-              if (mfmaLayout.getIsTransposed()) {
-                acc = generateScaledMFMAOp(intrinsicName,
-                                           operandB[kPack][{b, n, k}],
-                                           operandA[kPack][{b, m, k}], acc,
-                                           maybeMfmaIntrinsic->bElementType,
-                                           maybeMfmaIntrinsic->aElementType);
-              } else {
-                acc = generateScaledMFMAOp(intrinsicName,
-                                           operandA[kPack][{b, m, k}],
-                                           operandB[kPack][{b, n, k}], acc,
-                                           maybeMfmaIntrinsic->aElementType,
-                                           maybeMfmaIntrinsic->bElementType);
-              }
+              acc = generateScaledMFMAOp(intrinsicName, operandA[{b, m, k}],
                                         operandB[{b, n, k}], acc,
                                         maybeMfmaIntrinsic->aElementType,
                                         maybeMfmaIntrinsic->bElementType);
             }
-            if (!firstMfma)
-              firstMfma = acc;
           }
+          if (!firstMfma)
+            firstMfma = acc;
         }
         acc = reduceSubBlocks(subBlocks, acc);
         adjustAccForSmallKDim(fc, acc, dstElemTy, b, m, n, numRepM, numRepN,
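The new indexing relies on MLIR's computeStrides/linearize helpers (hence the added mlir/Dialect/Utils/IndexingUtils.h include) to flatten the {batch, nonKRep, numVecInKBase, kBase} coordinates into one offset into elems. Below is a self-contained re-implementation of that row-major scheme for illustration only; the example bounds are assumed, not taken from the PR:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Row-major strides: suffix products of the iteration bounds.
std::vector<int64_t> computeStrides(const std::vector<int64_t> &bounds) {
  std::vector<int64_t> strides(bounds.size(), 1);
  for (int i = static_cast<int>(bounds.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * bounds[i + 1];
  return strides;
}

// Dot product of the multi-dimensional offsets with the strides.
int64_t linearize(const std::vector<int64_t> &offsets,
                  const std::vector<int64_t> &strides) {
  int64_t index = 0;
  for (size_t i = 0; i < offsets.size(); ++i)
    index += offsets[i] * strides[i];
  return index;
}

int main() {
  // Assumed example: 1 batch, 2 non-K repetitions, 2 kBase-vectors, kBase = 8.
  std::vector<int64_t> bounds = {1, 2, 2, 8};
  std::vector<int64_t> strides = computeStrides(bounds); // {32, 16, 8, 1}
  // Element k = 3 of the first kBase-vector of (b = 0, nonK = 1):
  int64_t index = linearize({0, 1, 0, 3}, strides); // 0*32 + 1*16 + 0*8 + 3 = 19
  std::printf("flat index into elems: %lld\n", static_cast<long long>(index));
}
```

This expresses the old hand-written offset (kWidth * n1 * n0 * b + kWidth * n1 * i + kWidth * j + k) through strides, which stays correct now that each extracted vector holds kBase elements rather than kWidth elements.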
