@@ -136,9 +136,9 @@ warpsPerTileWMMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps) {
136136// If enforcedNonKDim is not zero, it will be used to overwrite the default
137137// logic to choose a MFMA with matching M/N dim.
138138FailureOr<MfmaIntrinsic>
139- chooseMfmaInstruction (int mfmaVersion, RankedTensorType cType, Type aElemType ,
140- Type bElemType, int inputKSize , int enforcedNonKDim ,
141- bool withScale, bool allowXF32) {
139+ chooseMfmaInstruction (Location loc, int mfmaVersion, RankedTensorType cType,
140+ Type aElemType, Type bElemType , int inputKSize ,
141+ int enforcedNonKDim, bool withScale, bool allowXF32) {
142142 // number of matrix elements along k dim per one MFMA instruction
143143 unsigned kDim = 0 ;
144144
@@ -169,7 +169,8 @@ chooseMfmaInstruction(int mfmaVersion, RankedTensorType cType, Type aElemType,
169169 MfmaIntrinsic::selectFor (mfmaVersion, mDim , nDim, inputKSize, aElemType,
170170 bElemType, withScale, allowXF32);
171171 if (failed (maybeMfmaIntrinsic))
172- llvm::report_fatal_error (" No match found in MFMA database\n " );
172+ return emitError (loc, " no matching matrix core intrinsic due to "
173+ " unsupported element type" );
173174
174175 kDim = maybeMfmaIntrinsic->kDim ;
175176 assert (kDim != 0 );
@@ -188,7 +189,7 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
188189 bool allowXF32 =
189190 dot.getInputPrecision () == InputPrecision::TF32 && mfmaVersion == 3 ;
190191 return chooseMfmaInstruction (
191- mfmaVersion, dot.getC ().getType (), aType.getElementType (),
192+ dot. getLoc (), mfmaVersion, dot.getC ().getType (), aType.getElementType (),
192193 dot.getB ().getType ().getElementType (), aType.getShape ().back (), nonKDim,
193194 withScale, allowXF32);
194195}
@@ -204,8 +205,8 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot,
204205 }
205206 Type aElemType = scaleDotElemTypeToMLIRType (ctx, dot.getAElemType ());
206207 Type bElemType = scaleDotElemTypeToMLIRType (ctx, dot.getBElemType ());
207- return chooseMfmaInstruction (mfmaVersion, dot.getC ().getType (), aElemType ,
208- bElemType, inputKDim, nonKDim,
208+ return chooseMfmaInstruction (dot. getLoc (), mfmaVersion, dot.getC ().getType (),
209+ aElemType, bElemType, inputKDim, nonKDim,
209210 /* withScale=*/ true , /* allowXF32=*/ false );
210211}
211212
@@ -215,9 +216,9 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot,
215216 // For scaled dot, we handle it with fp16 or bf16 emulation for now.
216217 Builder b (dot.getContext ());
217218 Type elemType = useFp16 ? b.getF16Type () : b.getBF16Type ();
218- return chooseMfmaInstruction (mfmaVersion, dot.getC ().getType (), elemType ,
219- elemType, dot. getA (). getType (). getShape (). back () ,
220- nonKDim,
219+ return chooseMfmaInstruction (dot. getLoc (), mfmaVersion, dot.getC ().getType (),
220+ elemType, elemType ,
221+ dot. getA (). getType (). getShape (). back (), nonKDim,
221222 /* withScale=*/ false , /* allowXF32=*/ false );
222223}
223224
0 commit comments