@@ -715,23 +715,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
715715 result.types .push_back (outputType);
716716}
717717
718- // / This builder is called on single-parameter unary operators that have scale
719- // / relationship between their input and output, expressed by the
720- // / UnaryOpQuantizationAttr.
721- static void buildUnaryOpWithQuantInfo (OpBuilder &builder,
722- OperationState &result, Type outputType,
723- Value input) {
724- result.addOperands (input);
718+ // / This builder is called on single-parameter negate operator
719+ // / to construct input and output zero points based on their
720+ // / types.
721+ static void buildNegateOpWithQuantInfo (OpBuilder &builder,
722+ OperationState &result, Type outputType,
723+ Value input) {
724+ const Location loc{result.location };
725+ int64_t input1Zp{0 };
726+ int64_t outputZp{0 };
725727 auto quantAttr = buildUnaryOpQuantizationAttr (builder, input, outputType);
726728 if (quantAttr) {
727- // note: negateOp has attributes input1_zp and output_zp
728- result.addAttribute (" input1_zp" ,
729- builder.getI32IntegerAttr (
730- static_cast <int32_t >(quantAttr.getInputZp ())));
731- result.addAttribute (" output_zp" ,
732- builder.getI32IntegerAttr (
733- static_cast <int32_t >(quantAttr.getOutputZp ())));
729+ input1Zp = quantAttr.getInputZp ();
730+ outputZp = quantAttr.getOutputZp ();
731+ }
732+ const std::optional<Value> input1ZpOp =
733+ createZeroPointTensor (builder, loc, input.getType (), input1Zp);
734+ if (!input1ZpOp) {
735+ (void )emitError (
736+ loc, " Failed to create input1 zero point for quantized NEGATE op" );
737+ }
738+
739+ const std::optional<Value> outputZpOp =
740+ createZeroPointTensor (builder, loc, input.getType (), outputZp);
741+ if (!outputZpOp) {
742+ (void )emitError (
743+ loc, " Failed to create output zero point for quantized NEGATE op" );
734744 }
745+
746+ if (input1ZpOp && outputZpOp) {
747+ result.addOperands ({input, input1ZpOp.value (), outputZpOp.value ()});
748+ } else {
749+ // failed to create one or more zero points above: just add input as
750+ // operands. This will trigger error in building the op because of
751+ // missing zero points
752+ result.addOperands ({input});
753+ }
754+
735755 result.types .push_back (outputType);
736756}
737757
@@ -1721,6 +1741,9 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
17211741ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
17221742ZERO_POINT_HELPER(AvgPool2dOp, Input)
17231743ZERO_POINT_HELPER(AvgPool2dOp, Output)
1744+ ZERO_POINT_HELPER(NegateOp, Input1)
1745+ ZERO_POINT_HELPER(NegateOp, Output)
1746+
17241747#undef ZERO_POINT_HELPER
17251748
17261749LogicalResult tosa::TransposeOp::inferReturnTypeComponents (
@@ -2222,7 +2245,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
22222245NARY_SHAPE_INFER(tosa::LogicalXorOp)
22232246NARY_SHAPE_INFER(tosa::MaximumOp)
22242247NARY_SHAPE_INFER(tosa::MinimumOp)
2225- NARY_SHAPE_INFER(tosa::NegateOp)
22262248NARY_SHAPE_INFER(tosa::PowOp)
22272249NARY_SHAPE_INFER(tosa::ReciprocalOp)
22282250NARY_SHAPE_INFER(tosa::RescaleOp)
@@ -2236,6 +2258,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
22362258NARY_SHAPE_INFER(tosa::SigmoidOp)
22372259#undef PRED_SHAPE_INFER
22382260
2261+ LogicalResult tosa::NegateOp::inferReturnTypeComponents (
2262+ MLIRContext *context, ::std::optional<Location> location,
2263+ NegateOp::Adaptor adaptor,
2264+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2265+ ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
2266+ inferredReturnShapes.push_back (ShapedTypeComponents (inputShape));
2267+ return success ();
2268+ }
2269+
2270+ LogicalResult tosa::NegateOp::verify () {
2271+ // Verify same element type
2272+ const Type input1Type = getInput1 ().getType ();
2273+ const Type outputType = getOutput ().getType ();
2274+ if (verifySameElementTypes (*this , input1Type, outputType).failed ())
2275+ return failure ();
2276+
2277+ // Verify same shape
2278+ const SmallVector<Type, 2 > types = {input1Type, outputType};
2279+ if (failed (verifyCompatibleShapes (types)))
2280+ return emitOpError () << " requires the same shape for input1 and output" ;
2281+
2282+ const Type input1EType = getStorageElementTypeOrSelf (getInput1 ().getType ());
2283+ const Type input1ZpEType =
2284+ getStorageElementTypeOrSelf (getInput1Zp ().getType ());
2285+ if (input1EType != input1ZpEType) {
2286+ return emitOpError (" expect both input1 and its zero point are the same "
2287+ " element type, got " )
2288+ << input1EType << " and " << input1ZpEType;
2289+ }
2290+ const Type outputEType = getStorageElementTypeOrSelf (getOutput ().getType ());
2291+ const Type outputZpEType =
2292+ getStorageElementTypeOrSelf (getOutputZp ().getType ());
2293+ if (outputEType != outputZpEType) {
2294+ return emitOpError (" expect both output and its zero point are the same "
2295+ " element type, got " )
2296+ << outputEType << " and " << outputZpEType;
2297+ }
2298+
2299+ FailureOr<int64_t > maybeIZp = getInput1ZeroPoint ();
2300+ if (succeeded (maybeIZp) && verifyInput1ZeroPoint (*maybeIZp).failed ())
2301+ return failure ();
2302+
2303+ FailureOr<int64_t > maybeOZp = getOutputZeroPoint ();
2304+ if (succeeded (maybeOZp) && verifyOutputZeroPoint (*maybeOZp).failed ())
2305+ return failure ();
2306+
2307+ return success ();
2308+ }
2309+
22392310static LogicalResult poolingInferReturnTypes (
22402311 ShapeAdaptor inputShape, ArrayRef<int64_t > kernel, ArrayRef<int64_t > stride,
22412312 ArrayRef<int64_t > pad,
0 commit comments