@@ -510,7 +510,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
510510 result.addAttribute (" stride" , stride);
511511 result.addAttribute (" dilation" , dilation);
512512 result.addAttribute (" acc_type" , accType);
513- result.addTypes (outputType);
513+ Type finalOutputType = outputType;
514+ auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
515+ if (quantAttr) {
516+ finalOutputType =
517+ buildConvOpResultTypeInfo (builder, outputType, input, weight);
518+ }
519+ result.addTypes (finalOutputType);
514520}
515521
516522// / Handles tosa.transpose_conv2d which has outpad and output shape
@@ -519,25 +525,19 @@ static void buildTransConvOpWithQuantInfo(
519525 OpBuilder &builder, OperationState &result, Type outputType, Value input,
520526 Value weight, Value bias, DenseI64ArrayAttr outpad,
521527 DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
522- result.addOperands ({input, weight, bias});
528+ auto zps = createZPsAsConst (builder, input, weight);
529+ result.addOperands ({input, weight, bias, zps.first , zps.second });
523530 result.addAttribute (" out_pad" , outpad);
524531 result.addAttribute (" stride" , stride);
525532 result.addAttribute (" out_shape" , outputShape);
526533 result.addAttribute (" acc_type" , accType);
527- auto quantAttr = :: buildConvOpQuantizationAttr (builder, input, weight) ;
528-
534+ Type finalOutputType = outputType ;
535+ auto quantAttr = buildConvOpQuantizationAttr (builder, input, weight);
529536 if (quantAttr) {
530- result.addAttribute (" input_zp" ,
531- builder.getI32IntegerAttr (
532- static_cast <int32_t >(quantAttr.getInputZp ())));
533- result.addAttribute (" weight_zp" ,
534- builder.getI32IntegerAttr (
535- static_cast <int32_t >(quantAttr.getWeightZp ())));
536- result.addTypes (
537- buildConvOpResultTypeInfo (builder, outputType, input, weight));
538- } else {
539- result.addTypes (outputType);
537+ finalOutputType =
538+ buildConvOpResultTypeInfo (builder, outputType, input, weight);
540539 }
540+ result.addTypes (finalOutputType);
541541}
542542
543543// / The tosa.fully_connected op has its own builder as it does not have
@@ -2492,18 +2492,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
24922492 return failure ();
24932493}
24942494
2495- // Create a rank-0 const tensor for zero point of the source tensor.
2495+ // Create a rank-1 const tensor for zero point of the source tensor.
24962496std::optional<Value> mlir::tosa::createZeroPointTensor (OpBuilder &builder,
24972497 Location loc,
24982498 Type srcElemType,
24992499 int64_t zp) {
2500- if (auto quantType =
2501- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
2502- srcElemType = quantType.getStorageType ();
2503-
2504- auto zpType = mlir::RankedTensorType::get ({1 }, srcElemType);
2500+ srcElemType = getElementTypeOrSelf (srcElemType);
25052501 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
25062502 srcElemType = quantType.getStorageType ();
2503+ auto zpType = mlir::RankedTensorType::get ({1 }, srcElemType);
25072504 if (llvm::isa<FloatType>(srcElemType)) {
25082505 auto zpAttr = DenseElementsAttr::get (
25092506 zpType, builder.getFloatAttr (srcElemType, static_cast <double >(zp)));
0 commit comments