@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
240240 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
241241 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
242242
243- if (auto quantType =
244- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
243+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
245244 inputEType = quantType.getStorageType ();
246245
247- if (auto quantType =
248- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
246+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
249247 biasEType = quantType.getStorageType ();
250248
251- if (auto quantType =
252- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
249+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
253250 resultEType = quantType.getStorageType ();
254251
255252 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
346343 auto inputEType =
347344 llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
348345
349- if (auto quantType =
350- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
346+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
351347 inputEType = quantType.getStorageType ();
352348
353349 auto accType = op.getAccType ();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
369365 if (inputEType.isF32 () && !accType.isF32 ())
370366 return op.emitOpError (" accumulator type for f32 tensor is not f32" );
371367
372- return success ();
368+ auto resultEType =
369+ llvm::cast<ShapedType>(op.getResult ().getType ()).getElementType ();
370+
371+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
372+ resultEType = quantType.getStorageType ();
373+
374+ // check allowed input/result element types combinations
375+ if ((inputEType.isInteger (8 ) && resultEType.isInteger (32 )) ||
376+ (inputEType.isInteger (16 ) && resultEType.isInteger (48 )) ||
377+ (isa<Float8E5M2Type>(inputEType) && resultEType.isF16 ()) ||
378+ (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16 ()) ||
379+ (inputEType.isF16 () && resultEType.isF16 ()) ||
380+ (inputEType.isBF16 () && resultEType.isBF16 ()) ||
381+ (inputEType.isF32 () && resultEType.isF32 ()))
382+ return success ();
383+
384+ return op.emitOpError (" input/output element types are incompatible." );
373385}
374386
375387// verify that inType and outType have same element types
0 commit comments