@@ -697,23 +697,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
697697 result.types .push_back (outputType);
698698}
699699
700- // / This builder is called on single-parameter unary operators that have scale
701- // / relationship between their input and output, expressed by the
702- // / UnaryOpQuantizationAttr.
703- static void buildUnaryOpWithQuantInfo (OpBuilder &builder,
704- OperationState &result, Type outputType,
705- Value input) {
706- result.addOperands (input);
700+ // / This builder is called on single-parameter negate operator
701+ // / to construct input and output zero points based on their
702+ // / types.
703+ static void buildNegateOpWithQuantInfo (OpBuilder &builder,
704+ OperationState &result, Type outputType,
705+ Value input) {
706+ const Location loc{result.location };
707+ int64_t input1Zp{0 };
708+ int64_t outputZp{0 };
707709 auto quantAttr = buildUnaryOpQuantizationAttr (builder, input, outputType);
708710 if (quantAttr) {
709- // note: negateOp has attributes input1_zp and output_zp
710- result.addAttribute (" input1_zp" ,
711- builder.getI32IntegerAttr (
712- static_cast <int32_t >(quantAttr.getInputZp ())));
713- result.addAttribute (" output_zp" ,
714- builder.getI32IntegerAttr (
715- static_cast <int32_t >(quantAttr.getOutputZp ())));
711+ input1Zp = quantAttr.getInputZp ();
712+ outputZp = quantAttr.getOutputZp ();
713+ }
714+ const std::optional<Value> input1ZpOp =
715+ createZeroPointTensor (builder, loc, input.getType (), input1Zp);
716+ if (!input1ZpOp) {
717+ (void )emitError (
718+ loc, " Failed to create input1 zero point for quantized NEGATE op" );
719+ }
720+
721+ const std::optional<Value> outputZpOp =
722+ createZeroPointTensor (builder, loc, input.getType (), outputZp);
723+ if (!outputZpOp) {
724+ (void )emitError (
725+ loc, " Failed to create output zero point for quantized NEGATE op" );
716726 }
727+
728+ if (input1ZpOp && outputZpOp) {
729+ result.addOperands ({input, input1ZpOp.value (), outputZpOp.value ()});
730+ } else {
731+ // failed to create one or more zero points above: just add input as
732+ // operands. This will trigger error in building the op because of
733+ // missing zero points
734+ result.addOperands ({input});
735+ }
736+
717737 result.types .push_back (outputType);
718738}
719739
@@ -1729,6 +1749,9 @@ ZERO_POINT_HELPER(AvgPool2dOp, Input)
17291749ZERO_POINT_HELPER(AvgPool2dOp, Output)
17301750ZERO_POINT_HELPER(MatMulOp, A)
17311751ZERO_POINT_HELPER(MatMulOp, B)
1752+ ZERO_POINT_HELPER(NegateOp, Input1)
1753+ ZERO_POINT_HELPER(NegateOp, Output)
1754+
17321755#undef ZERO_POINT_HELPER
17331756
17341757LogicalResult tosa::TransposeOp::inferReturnTypeComponents (
@@ -2231,7 +2254,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
22312254NARY_SHAPE_INFER(tosa::LogicalXorOp)
22322255NARY_SHAPE_INFER(tosa::MaximumOp)
22332256NARY_SHAPE_INFER(tosa::MinimumOp)
2234- NARY_SHAPE_INFER(tosa::NegateOp)
22352257NARY_SHAPE_INFER(tosa::PowOp)
22362258NARY_SHAPE_INFER(tosa::ReciprocalOp)
22372259NARY_SHAPE_INFER(tosa::ReverseOp)
@@ -2244,6 +2266,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
22442266NARY_SHAPE_INFER(tosa::SigmoidOp)
22452267#undef PRED_SHAPE_INFER
22462268
2269+ LogicalResult tosa::NegateOp::inferReturnTypeComponents (
2270+ MLIRContext *context, ::std::optional<Location> location,
2271+ NegateOp::Adaptor adaptor,
2272+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2273+ ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
2274+ inferredReturnShapes.push_back (ShapedTypeComponents (inputShape));
2275+ return success ();
2276+ }
2277+
2278+ LogicalResult tosa::NegateOp::verify () {
2279+ // Verify same element type
2280+ const Type input1Type = getInput1 ().getType ();
2281+ const Type outputType = getOutput ().getType ();
2282+ if (verifySameElementTypes (*this , input1Type, outputType).failed ())
2283+ return failure ();
2284+
2285+ // Verify same shape
2286+ const SmallVector<Type, 2 > types = {input1Type, outputType};
2287+ if (failed (verifyCompatibleShapes (types)))
2288+ return emitOpError () << " requires the same shape for input1 and output" ;
2289+
2290+ const Type input1EType = getStorageElementTypeOrSelf (getInput1 ().getType ());
2291+ const Type input1ZpEType =
2292+ getStorageElementTypeOrSelf (getInput1Zp ().getType ());
2293+ if (input1EType != input1ZpEType) {
2294+ return emitOpError (" expect both input1 and its zero point are the same "
2295+ " element type, got " )
2296+ << input1EType << " and " << input1ZpEType;
2297+ }
2298+ const Type outputEType = getStorageElementTypeOrSelf (getOutput ().getType ());
2299+ const Type outputZpEType =
2300+ getStorageElementTypeOrSelf (getOutputZp ().getType ());
2301+ if (outputEType != outputZpEType) {
2302+ return emitOpError (" expect both output and its zero point are the same "
2303+ " element type, got " )
2304+ << outputEType << " and " << outputZpEType;
2305+ }
2306+
2307+ FailureOr<int64_t > maybeIZp = getInput1ZeroPoint ();
2308+ if (succeeded (maybeIZp) && verifyInput1ZeroPoint (*maybeIZp).failed ())
2309+ return failure ();
2310+
2311+ FailureOr<int64_t > maybeOZp = getOutputZeroPoint ();
2312+ if (succeeded (maybeOZp) && verifyOutputZeroPoint (*maybeOZp).failed ())
2313+ return failure ();
2314+
2315+ return success ();
2316+ }
2317+
22472318static LogicalResult poolingInferReturnTypes (
22482319 ShapeAdaptor inputShape, ArrayRef<int64_t > kernel, ArrayRef<int64_t > stride,
22492320 ArrayRef<int64_t > pad,
0 commit comments