@@ -136,9 +136,9 @@ warpsPerTileWMMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps) {
136
136
// If enforcedNonKDim is not zero, it will be used to override the default
137
137
// logic to choose an MFMA with matching M/N dim.
138
138
FailureOr<MfmaIntrinsic>
139
- chooseMfmaInstruction (int mfmaVersion, RankedTensorType cType, Type aElemType ,
140
- Type bElemType, int inputKSize , int enforcedNonKDim ,
141
- bool withScale, bool allowXF32) {
139
+ chooseMfmaInstruction (Location loc, int mfmaVersion, RankedTensorType cType,
140
+ Type aElemType, Type bElemType , int inputKSize ,
141
+ int enforcedNonKDim, bool withScale, bool allowXF32) {
142
142
// number of matrix elements along k dim per one MFMA instruction
143
143
unsigned kDim = 0 ;
144
144
@@ -169,7 +169,8 @@ chooseMfmaInstruction(int mfmaVersion, RankedTensorType cType, Type aElemType,
169
169
MfmaIntrinsic::selectFor (mfmaVersion, mDim , nDim, inputKSize, aElemType,
170
170
bElemType, withScale, allowXF32);
171
171
if (failed (maybeMfmaIntrinsic))
172
- llvm::report_fatal_error (" No match found in MFMA database\n " );
172
+ return emitError (loc, " no matching matrix core intrinsic due to "
173
+ " unsupported element type" );
173
174
174
175
kDim = maybeMfmaIntrinsic->kDim ;
175
176
assert (kDim != 0 );
@@ -188,7 +189,7 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
188
189
bool allowXF32 =
189
190
dot.getInputPrecision () == InputPrecision::TF32 && mfmaVersion == 3 ;
190
191
return chooseMfmaInstruction (
191
- mfmaVersion, dot.getC ().getType (), aType.getElementType (),
192
+ dot. getLoc (), mfmaVersion, dot.getC ().getType (), aType.getElementType (),
192
193
dot.getB ().getType ().getElementType (), aType.getShape ().back (), nonKDim,
193
194
withScale, allowXF32);
194
195
}
@@ -204,8 +205,8 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot,
204
205
}
205
206
Type aElemType = scaleDotElemTypeToMLIRType (ctx, dot.getAElemType ());
206
207
Type bElemType = scaleDotElemTypeToMLIRType (ctx, dot.getBElemType ());
207
- return chooseMfmaInstruction (mfmaVersion, dot.getC ().getType (), aElemType ,
208
- bElemType, inputKDim, nonKDim,
208
+ return chooseMfmaInstruction (dot. getLoc (), mfmaVersion, dot.getC ().getType (),
209
+ aElemType, bElemType, inputKDim, nonKDim,
209
210
/* withScale=*/ true , /* allowXF32=*/ false );
210
211
}
211
212
@@ -215,9 +216,9 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot,
215
216
// For scaled dot, we handle it with fp16 or bf16 emulation for now.
216
217
Builder b (dot.getContext ());
217
218
Type elemType = useFp16 ? b.getF16Type () : b.getBF16Type ();
218
- return chooseMfmaInstruction (mfmaVersion, dot.getC ().getType (), elemType ,
219
- elemType, dot. getA (). getType (). getShape (). back () ,
220
- nonKDim,
219
+ return chooseMfmaInstruction (dot. getLoc (), mfmaVersion, dot.getC ().getType (),
220
+ elemType, elemType ,
221
+ dot. getA (). getType (). getShape (). back (), nonKDim,
221
222
/* withScale=*/ false , /* allowXF32=*/ false );
222
223
}
223
224
0 commit comments