@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
240240 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
241241 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
242242
243- if (auto quantType =
244- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
243+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
245244 inputEType = quantType.getStorageType ();
246245
247- if (auto quantType =
248- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
246+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
249247 biasEType = quantType.getStorageType ();
250248
251- if (auto quantType =
252- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
249+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
253250 resultEType = quantType.getStorageType ();
254251
255252 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
346343 auto inputEType =
347344 llvm::cast<ShapedType>(op.getInput ().getType ()).getElementType ();
348345
349- if (auto quantType =
350- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
346+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
351347 inputEType = quantType.getStorageType ();
352348
353349 auto accType = op.getAccType ();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
369365 if (inputEType.isF32 () && !accType.isF32 ())
370366 return op.emitOpError (" accumulator type for f32 tensor is not f32" );
371367
372- return success ();
368+ auto resultEType =
369+ llvm::cast<ShapedType>(op.getResult ().getType ()).getElementType ();
370+
371+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
372+ resultEType = quantType.getStorageType ();
373+
374+ // check allowed input/result element types combinations
375+ if ((inputEType.isInteger (8 ) && resultEType.isInteger (32 )) ||
376+ (inputEType.isInteger (16 ) && resultEType.isInteger (48 )) ||
377+ (isa<Float8E5M2Type>(inputEType) && resultEType.isF16 ()) ||
378+ (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16 ()) ||
379+ (inputEType.isF16 () && resultEType.isF16 ()) ||
380+ (inputEType.isBF16 () && resultEType.isBF16 ()) ||
381+ (inputEType.isF32 () && resultEType.isF32 ()))
382+ return success ();
383+
384+ return op.emitOpError (" input/output element types are incompatible." );
373385}
374386
375387// verify that inType and outType have same element types
@@ -519,7 +531,8 @@ static void buildTransConvOpWithQuantInfo(
519531 OpBuilder &builder, OperationState &result, Type outputType, Value input,
520532 Value weight, Value bias, DenseI64ArrayAttr outpad,
521533 DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
522- result.addOperands ({input, weight, bias});
534+ auto zps = createZPsAsConst (builder, input, weight);
535+ result.addOperands ({input, weight, bias, zps.first , zps.second });
523536 result.addAttribute (" out_pad" , outpad);
524537 result.addAttribute (" stride" , stride);
525538 result.addAttribute (" out_shape" , outputShape);
@@ -2478,18 +2491,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
24782491 return failure ();
24792492}
24802493
2481- // Create a rank-0 const tensor for zero point of the source tensor.
2494+ // Create a rank-1 const tensor for zero point of the source tensor.
24822495std::optional<Value> mlir::tosa::createZeroPointTensor (OpBuilder &builder,
24832496 Location loc,
24842497 Type srcElemType,
24852498 int64_t zp) {
2486- if (auto quantType =
2487- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
2488- srcElemType = quantType.getStorageType ();
2489-
2490- auto zpType = mlir::RankedTensorType::get ({1 }, srcElemType);
2499+ srcElemType = getElementTypeOrSelf (srcElemType);
24912500 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
24922501 srcElemType = quantType.getStorageType ();
2502+ auto zpType = mlir::RankedTensorType::get ({1 }, srcElemType);
24932503 if (llvm::isa<FloatType>(srcElemType)) {
24942504 auto zpAttr = DenseElementsAttr::get (
24952505 zpType, builder.getFloatAttr (srcElemType, static_cast <double >(zp)));
0 commit comments