diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 955021abdd67b..fd166cc1322ce 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -510,7 +510,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, result.addAttribute("stride", stride); result.addAttribute("dilation", dilation); result.addAttribute("acc_type", accType); - result.addTypes(outputType); + Type finalOutputType = outputType; + auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); + if (quantAttr) { + finalOutputType = + buildConvOpResultTypeInfo(builder, outputType, input, weight); + } + result.addTypes(finalOutputType); } /// Handles tosa.transpose_conv2d which has outpad and output shape @@ -519,25 +525,19 @@ static void buildTransConvOpWithQuantInfo( OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) { - result.addOperands({input, weight, bias}); + auto zps = createZPsAsConst(builder, input, weight); + result.addOperands({input, weight, bias, zps.first, zps.second}); result.addAttribute("out_pad", outpad); result.addAttribute("stride", stride); result.addAttribute("out_shape", outputShape); result.addAttribute("acc_type", accType); - auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); - + Type finalOutputType = outputType; + auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); if (quantAttr) { - result.addAttribute("input_zp", - builder.getI32IntegerAttr( - static_cast(quantAttr.getInputZp()))); - result.addAttribute("weight_zp", - builder.getI32IntegerAttr( - static_cast(quantAttr.getWeightZp()))); - result.addTypes( - buildConvOpResultTypeInfo(builder, outputType, input, weight)); - } else { - result.addTypes(outputType); + finalOutputType = + buildConvOpResultTypeInfo(builder, outputType, input, weight); } + result.addTypes(finalOutputType); } /// The tosa.fully_connected op has its own builder as it does not have @@ -2492,18 +2492,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) { return failure(); } -// Create a rank-0 const tensor for zero point of the source tensor. +// Create a rank-1 const tensor for zero point of the source tensor. std::optional mlir::tosa::createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp) { - if (auto quantType = - llvm::dyn_cast(srcElemType)) - srcElemType = quantType.getStorageType(); - - auto zpType = mlir::RankedTensorType::get({1}, srcElemType); + srcElemType = getElementTypeOrSelf(srcElemType); if (auto quantType = llvm::dyn_cast(srcElemType)) srcElemType = quantType.getStorageType(); + auto zpType = mlir::RankedTensorType::get({1}, srcElemType); if (llvm::isa(srcElemType)) { auto zpAttr = DenseElementsAttr::get( zpType, builder.getFloatAttr(srcElemType, static_cast(zp)));