Commit b611ccd

[AMD] Add error message for invalid MFMA tile k-dimension size (#8301)
This PR adds additional remarks to help with debugging how dot operation variants are mapped to MFMA and FMA intrinsics.

## Added Remarks

- Each failed call to `chooseMfmaInstruction` emits a remark reporting that MFMA intrinsic selection failed, together with the arguments of the request. In particular, there is a remark both for the initial selection failure and for the further failure where the k-dimension of the tile is not a multiple of the k-dimension of the intrinsic.
- Emits a generic remark if the `tt::DotOp` cannot be mapped to a `V_MFMA_*_F8F6F4` intrinsic.
- Emits a generic remark if the `tt::DotOp` cannot be mapped to any MFMA intrinsic.
- Emits a generic remark when a `tt::DotScaledOp` is decomposed into a `tt::DotOp` with explicit scaling. This occurs when the `tt::DotScaledOp` cannot be mapped to a `V_MFMA_*_F8F6F4` intrinsic.
- Emits a generic remark when a `tt::DotOp` is mapped to an FMA intrinsic; this appears never to fail unless the operation is already in FMA form.
1 parent 8fc6ccd commit b611ccd
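To make the new diagnostics concrete, here is a condensed editor's sketch of the "remark on failure" pattern that the commit adds to `chooseMfmaInstruction`. It is not the verbatim pass code; `MfmaQuery` and `selectMfmaOrRemark` are hypothetical stand-ins for values and logic the real function already has in scope (see the diff of `AccelerateAMDMatmul.cpp` below for the actual change, which also prints the element types and the `withScale`/`allowXF32`/`enforcedNonKDim` flags).

```cpp
// Editor's sketch only: the shape of the remark-on-failure pattern added in
// this commit. MfmaQuery is a hypothetical container for the values the real
// chooseMfmaInstruction already has in scope.
#include <cstdint>

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"

namespace {
struct MfmaQuery {
  int mfmaVersion;
  std::int64_t M, N, inputKSize, kDim;
  bool intrinsicFound;
};

mlir::LogicalResult selectMfmaOrRemark(mlir::Location loc, const MfmaQuery &q) {
  if (!q.intrinsicFound) {
    // First failure mode: no MFMA intrinsic matches the requested shape/types.
    mlir::emitRemark(loc) << "Unable to select MFMA intrinsic for the request: "
                          << "version=" << q.mfmaVersion << ", result-shape=("
                          << q.M << "x" << q.N << "), inputKSize="
                          << q.inputKSize;
    return mlir::failure();
  }
  if (q.inputKSize % q.kDim != 0) {
    // Second failure mode: the tile's k extent is not a multiple of the
    // intrinsic's k dimension, which would force data duplication.
    mlir::emitRemark(loc) << "MFMA kDim=" << q.kDim
                          << " does not divide tile inputKSize=" << q.inputKSize;
    return mlir::failure();
  }
  return mlir::success();
}
} // namespace
```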

2 files changed, +98 -29 lines changed

test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir

Lines changed: 43 additions & 16 deletions
@@ -1,5 +1,10 @@
-// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=0" --verify-diagnostics | FileCheck %s --check-prefixes MFMA0,CHECK
-// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=16" --verify-diagnostics | FileCheck %s --check-prefixes MFMA16,CHECK
+// RUN: split-file %s %t
+// RUN: cat %t/common.mlir %t/mfma0.mlir > %t/run-mfma0.mlir
+// RUN: triton-opt %t/run-mfma0.mlir -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=0" --verify-diagnostics | FileCheck %t/run-mfma0.mlir --check-prefixes=MFMA0,CHECK
+// RUN: cat %t/common.mlir %t/mfma16.mlir > %t/run-mfma16.mlir
+// RUN: triton-opt %t/run-mfma16.mlir -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=16" --verify-diagnostics | FileCheck %t/run-mfma16.mlir --check-prefixes=MFMA16,CHECK
+
+//--- common.mlir

 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
 // CHECK-LABEL: mfma_dot_fp8e5m2_fp8e4m3fn
@@ -64,6 +69,28 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n

 // -----

+// MFMA0-NOT: amd_mfma
+// MFMA16-NOT: amd_mfma
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
+// CHECK-LABEL: mfma_dot_small_k
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_small_k(
+    %arg0: tensor<128x4xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
+    %arg1: tensor<4x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
+    %arg2: tensor<128x256x!tt.ptr<f32>, #blocked> ) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
+    // expected-remark @+2 {{Unable to select MFMA intrinsic}}
+    // expected-remark @+1 {{Attempting to map dot operation to FMA intrinsic.}}
+    %1 = tt.dot %arg0, %arg1, %cst : tensor<128x4xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<4x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+    tt.store %arg2, %1 : tensor<128x256x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+//--- mfma0.mlir
+
 // MFMA0-NOT: amd_mfma
 // MFMA16: #mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 2], instrShape = [16, 16, 16], isTransposed = true}>
 // CHECK-LABEL: small_m_size_fma
@@ -74,26 +101,26 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
     %b: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>)
     -> tensor<1x128xf32, #blocked> {
     %zero_f32 = arith.constant dense<0.000000e+00> : tensor<1x128xf32, #blocked>
+    // expected-remark @+2 {{Unable to select MFMA intrinsic}}
+    // expected-remark @+1 {{Attempting to map dot operation to FMA intrinsic.}}
     %result = tt.dot %a, %b, %zero_f32 : tensor<1x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<1x128xf32, #blocked>
     tt.return %result : tensor<1x128xf32, #blocked>
   }
 }

-
-// -----
+//--- mfma16.mlir

 // MFMA0-NOT: amd_mfma
-// MFMA16-NOT: amd_mfma
-#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
-// CHECK-LABEL: mfma_dot_small_k
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @mfma_dot_small_k(
-    %arg0: tensor<128x4xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
-    %arg1: tensor<4x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
-    %arg2: tensor<128x256x!tt.ptr<f32>, #blocked> ) {
-    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
-    %1 = tt.dot %arg0, %arg1, %cst : tensor<128x4xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<4x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
-    tt.store %arg2, %1 : tensor<128x256x!tt.ptr<f32>, #blocked>
-    tt.return
+// MFMA16: #mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 2], instrShape = [16, 16, 16], isTransposed = true}>
+// CHECK-LABEL: small_m_size_fma
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 64], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @small_m_size_fma(
+    %a: tensor<1x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
+    %b: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>)
+    -> tensor<1x128xf32, #blocked> {
+    %zero_f32 = arith.constant dense<0.000000e+00> : tensor<1x128xf32, #blocked>
+    %result = tt.dot %a, %b, %zero_f32 : tensor<1x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<1x128xf32, #blocked>
+    tt.return %result : tensor<1x128xf32, #blocked>
   }
 }
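As a quick sanity check on the `mfma_dot_small_k` expectations above, the following editor's sketch (not part of the commit) reproduces the k-divisibility arithmetic: the dot tile has a k extent of 4, and the candidate intrinsic k sizes, assumed here to be 16 and 8 for the gfx942 fp16 MFMA variants, do not divide it, so MFMA selection fails and the FMA path is attempted.

```cpp
// Editor's sketch: why mfma_dot_small_k expects both remarks. The candidate
// kDim values are assumptions about the fp16 MFMA variants (16x16x16 and
// 32x32x8); the pass reads the real values from its intrinsic tables.
#include <cstdint>

constexpr std::int64_t inputKSize = 4; // tensor<128x4xf16> * tensor<4x256xf16>

constexpr bool divides(std::int64_t kDim) { return inputKSize % kDim == 0; }

// Neither candidate divides the tile's k extent, so chooseMfmaInstruction
// emits "Unable to select MFMA intrinsic ..." and returns failure(), after
// which the dot is handed to the FMA pattern (the second expected-remark).
static_assert(!divides(16) && !divides(8), "both MFMA candidates are rejected");
```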

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 55 additions & 13 deletions
@@ -199,16 +199,37 @@ chooseMfmaInstruction(Location loc, int mfmaVersion, RankedTensorType cType,
                                       aElemType, bElemType, withScale, allowXF32);

   // Fallback to FMA if the M/N dim is not supported by MFMA.
-  if (failed(maybeMfmaIntrinsic))
+  if (failed(maybeMfmaIntrinsic)) {
+    mlir::emitRemark(loc) << "Unable to select MFMA intrinsic for the request: "
+                          << "version=" << mfmaVersion << ", result-shape=("
+                          << M << "x" << N << "), selected-tiles=(" << mDim
+                          << "x" << nDim << "), inputKSize=" << inputKSize
+                          << ", aElemType=" << aElemType
+                          << ", bElemType=" << bElemType
+                          << ", withScale=" << (withScale ? "true" : "false")
+                          << ", allowXF32=" << (allowXF32 ? "true" : "false")
+                          << (enforcedNonKDim != 0
+                                  ? (llvm::Twine(", enforcedNonKDim=") +
+                                     llvm::Twine(enforcedNonKDim))
+                                        .str()
+                                  : "");
     return failure();
+  }

   kDim = maybeMfmaIntrinsic->kDim;
   assert(kDim != 0);
   assert(enforcedNonKDim != 0 || (M % mDim == 0 && N % nDim == 0));
   // If inputKSize % kDim != 0 (including the case where inputKSize < kDim),
   // this layout will introduce data duplication.
-  if (inputKSize % kDim != 0)
+  if (inputKSize % kDim != 0) {
+    mlir::emitRemark(loc)
+        << "Unable to select MFMA intrinsic '" << maybeMfmaIntrinsic->name
+        << "' as MFMA intrinsic k-dimension size kDim=" << kDim
+        << ", which is not a multiple of tile k-dimension size inputKSize="
+        << inputKSize
+        << ". Using this intrinsic would introduce data duplication.";
     return failure();
+  }
   return maybeMfmaIntrinsic;
 }

@@ -548,11 +569,15 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
         chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, withScale);
     if (failed(mfmaInstr)) {
       if (!withScale) {
-        return failure();
+        return rewriter.notifyMatchFailure(
+            dotOp,
+            "Unable to choose preferable MFMA intrinsic for dot operation.");
       }
       mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, false);
-      if (failed(mfmaInstr))
-        return failure();
+      if (failed(mfmaInstr)) {
+        return rewriter.notifyMatchFailure(
+            dotOp, "Unable to choose MFMA intrinsic for dot operation.");
+      }

       withScale = false;
     }
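A note on the two reporting channels used in this hunk and throughout the commit: `emitRemark` goes through MLIR's diagnostic engine, so lit tests can match it with `expected-remark` under `--verify-diagnostics`, whereas `rewriter.notifyMatchFailure` only records why a pattern did not apply and is typically surfaced through pattern-application debug output rather than to end users. The sketch below is an editor's illustration with a hypothetical `ReportingExample` pattern, not code from this commit.

```cpp
// Editor's illustration of the two channels; ReportingExample is hypothetical.
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct ReportingExample : mlir::RewritePattern {
  explicit ReportingExample(mlir::MLIRContext *ctx)
      : mlir::RewritePattern(mlir::Pattern::MatchAnyOpTypeTag(),
                             /*benefit=*/1, ctx) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // User-visible remark: emitted through the DiagnosticEngine, matchable
    // with `// expected-remark` in lit tests.
    op->emitRemark() << "Attempting to map dot operation to FMA intrinsic.";

    // Match-failure note: explains why the pattern bailed out; it is meant
    // for debugging the rewrite driver rather than for end users.
    return rewriter.notifyMatchFailure(
        op, "Unable to choose MFMA intrinsic for dot operation.");
  }
};
} // namespace
```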
@@ -769,7 +794,8 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     FailureOr<MfmaIntrinsic> mfmaInstr =
         chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, useFp16);
     if (failed(mfmaInstr))
-      return rewriter.notifyMatchFailure(dotOp, "cannot choose mfma intrinsic");
+      return rewriter.notifyMatchFailure(
+          dotOp, "Unable to choose MFMA intrinsic for scaled dot operation.");

     if (useFp16) {
       dotOp.emitRemark(
@@ -895,6 +921,13 @@ class DecomposeAMDScaledBlocked final : public ttg::DecomposeScaledBlocked {
       : ttg::DecomposeScaledBlocked(context, benefit) {}
   using TensorValue = TypedValue<RankedTensorType>;

+  LogicalResult matchAndRewrite(tt::DotScaledOp dotOp,
+                                PatternRewriter &rewriter) const override {
+    dotOp.emitRemark() << "Decomposing scaled dot operation into regular dot "
+                          "operation with explicit scaling.";
+    return ttg::DecomposeScaledBlocked::matchAndRewrite(dotOp, rewriter);
+  }
+
   RankedTensorType getScaleType(RankedTensorType vType, int32_t kDim,
                                 bool isFp4) const {
     if (!isFp4)
@@ -1018,9 +1051,11 @@ class ScaledBlockedToScaledMFMAF8F6F4 final
     // Choose a suitable Scaled MFMA instruction for this scaled dot op.
     FailureOr<MfmaIntrinsic> mfmaInstr =
         chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim);
-    if (failed(mfmaInstr))
+    if (failed(mfmaInstr)) {
       return rewriter.notifyMatchFailure(dotOp,
-                                         "cannot choose scaled mfma intrinsic");
+                                         "Unable to choose preferable MFMA "
+                                         "intrinsic for scaled dot operation.");
+    }

     auto mDim = mfmaInstr->mDim;
     auto nDim = mfmaInstr->nDim;
@@ -1474,7 +1509,8 @@ class BlockedToWMMA : public OpRewritePattern<tt::DotOp> {
     FailureOr<WmmaIntrinsic> wmmaInstr =
         chooseWmmaInstruction(dotOp, operandTypes, wmmaVersion, nonKDim);
     if (failed(wmmaInstr)) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dotOp, "Unable to choose WMMA intrinsic for dot operation.");
     }

     auto mDim = wmmaInstr->mDim;
@@ -1625,7 +1661,8 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
   LogicalResult tryAccelerateF16WithVDot(DotOp dotOp, PatternRewriter &rewriter,
                                          const DotElTypes &dotTypes) const {
     if (!AMD::supportsVDot(arch))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dotOp, "Target architecture does not support V_DOT instruction.");

     // If this is fp16 x fp16 ->fp16 case prioritize using v_dot.
     auto aOpType = dotOp.getA().getType();
@@ -1641,7 +1678,8 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
       rewriter.replaceOp(dotOp, newD);
       return success();
     }
-    return failure();
+    return rewriter.notifyMatchFailure(
+        dotOp, "Unable to choose V_DOT instruction for dot operation.");
   }

   LogicalResult tryLegalizeFMA(DotOp dotOp, PatternRewriter &rewriter,
@@ -1687,7 +1725,10 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
   LogicalResult matchAndRewrite(DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
     if (!isa<BlockedEncodingAttr>(dotOp.getD().getType().getEncoding()))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dotOp, "expected blocked encoding result tensor");
+
+    dotOp.emitRemark() << "Attempting to map dot operation to FMA intrinsic.";

     DotElTypes dotTypes;
     dotTypes.a = dotOp.getA().getType().getElementType();
@@ -1697,7 +1738,8 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {

     // Check that dot is not legalized already
     if (isLegalFMAForm(dotOp, dotTypes)) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dotOp, "Dot operation is already in FMA form.");
     }

     // TODO: enable this condition, when fp32 -> fp16 cast works correctly
