Skip to content

Commit 19c842c

Browse files
authored
[AMD] Use a proper error message for no matching MFMA intrinsic (#7185)
If we cannot match a proper MFMA intrinsic, it pretty much is due to that the element type is not supported. So make the error message more explanatory instead of using fatal error with a vague message.
1 parent 993c8da commit 19c842c

File tree

3 files changed

+30
-12
lines changed

3 files changed

+30
-12
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=0" --verify-diagnostics
2+
3+
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
4+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
5+
tt.func public @mfma_dot_fp8e4m3fn(
6+
%arg0: tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
7+
%arg1: tensor<64x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
8+
%arg2: tensor<128x256x!tt.ptr<f32>, #blocked>) {
9+
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
10+
// expected-error @+1 {{no matching matrix core intrinsic due to unsupported element type}}
11+
%1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
12+
tt.store %arg2, %1 : tensor<128x256x!tt.ptr<f32>, #blocked>
13+
tt.return
14+
}
15+
}

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ struct DotOpMFMAConversionHelper {
275275
mfmaVersion, mDim, nDim, kDimOperandSize, elemTyA, elemTyB,
276276
/*withScale=*/false, allowXF32);
277277
if (failed(maybeMfmaIntrinsic))
278-
llvm::report_fatal_error("No match found in MFMA database\n");
278+
return op.emitError(
279+
"no matching matrix core intrinsic due to unsupported element type");
279280

280281
unsigned kBase = maybeMfmaIntrinsic->kBase;
281282

@@ -590,7 +591,8 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
590591
scaleDotElemTypeToMLIRType(ctx, bElemType),
591592
/*withScale=*/true, allowXF32);
592593
if (failed(maybeMfmaIntrinsic))
593-
llvm::report_fatal_error("No match found in MFMA database\n");
594+
return op.emitError(
595+
"no matching matrix core intrinsic due to unsupported element type");
594596

595597
StringRef intrinsicName = maybeMfmaIntrinsic->name;
596598
unsigned kBase = maybeMfmaIntrinsic->kBase;

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ warpsPerTileWMMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps) {
136136
// If enforcedNonKDim is not zero, it will be used to overwrite the default
137137
// logic to choose a MFMA with matching M/N dim.
138138
FailureOr<MfmaIntrinsic>
139-
chooseMfmaInstruction(int mfmaVersion, RankedTensorType cType, Type aElemType,
140-
Type bElemType, int inputKSize, int enforcedNonKDim,
141-
bool withScale, bool allowXF32) {
139+
chooseMfmaInstruction(Location loc, int mfmaVersion, RankedTensorType cType,
140+
Type aElemType, Type bElemType, int inputKSize,
141+
int enforcedNonKDim, bool withScale, bool allowXF32) {
142142
// number of matrix elements along k dim per one MFMA instruction
143143
unsigned kDim = 0;
144144

@@ -169,7 +169,8 @@ chooseMfmaInstruction(int mfmaVersion, RankedTensorType cType, Type aElemType,
169169
MfmaIntrinsic::selectFor(mfmaVersion, mDim, nDim, inputKSize, aElemType,
170170
bElemType, withScale, allowXF32);
171171
if (failed(maybeMfmaIntrinsic))
172-
llvm::report_fatal_error("No match found in MFMA database\n");
172+
return emitError(loc, "no matching matrix core intrinsic due to "
173+
"unsupported element type");
173174

174175
kDim = maybeMfmaIntrinsic->kDim;
175176
assert(kDim != 0);
@@ -188,7 +189,7 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
188189
bool allowXF32 =
189190
dot.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
190191
return chooseMfmaInstruction(
191-
mfmaVersion, dot.getC().getType(), aType.getElementType(),
192+
dot.getLoc(), mfmaVersion, dot.getC().getType(), aType.getElementType(),
192193
dot.getB().getType().getElementType(), aType.getShape().back(), nonKDim,
193194
withScale, allowXF32);
194195
}
@@ -204,8 +205,8 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot,
204205
}
205206
Type aElemType = scaleDotElemTypeToMLIRType(ctx, dot.getAElemType());
206207
Type bElemType = scaleDotElemTypeToMLIRType(ctx, dot.getBElemType());
207-
return chooseMfmaInstruction(mfmaVersion, dot.getC().getType(), aElemType,
208-
bElemType, inputKDim, nonKDim,
208+
return chooseMfmaInstruction(dot.getLoc(), mfmaVersion, dot.getC().getType(),
209+
aElemType, bElemType, inputKDim, nonKDim,
209210
/*withScale=*/true, /*allowXF32=*/false);
210211
}
211212

@@ -215,9 +216,9 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot,
215216
// For scaled dot, we handle it with fp16 or bf16 emulation for now.
216217
Builder b(dot.getContext());
217218
Type elemType = useFp16 ? b.getF16Type() : b.getBF16Type();
218-
return chooseMfmaInstruction(mfmaVersion, dot.getC().getType(), elemType,
219-
elemType, dot.getA().getType().getShape().back(),
220-
nonKDim,
219+
return chooseMfmaInstruction(dot.getLoc(), mfmaVersion, dot.getC().getType(),
220+
elemType, elemType,
221+
dot.getA().getType().getShape().back(), nonKDim,
221222
/*withScale=*/false, /*allowXF32=*/false);
222223
}
223224

0 commit comments

Comments
 (0)