@@ -708,23 +708,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
708708 result.types .push_back (outputType);
709709}
710710
711- // / This builder is called on single-parameter unary operators that have scale
712- // / relationship between their input and output, expressed by the
713- // / UnaryOpQuantizationAttr.
714- static void buildUnaryOpWithQuantInfo (OpBuilder &builder,
715- OperationState &result, Type outputType,
716- Value input) {
717- result.addOperands (input);
711+ // / This builder is called on single-parameter negate operator
712+ // / to construct input and output zero points based on their
713+ // / types.
714+ static void buildNegateOpWithQuantInfo (OpBuilder &builder,
715+ OperationState &result, Type outputType,
716+ Value input) {
717+ const Location loc{result.location };
718+ int64_t input1Zp{0 };
719+ int64_t outputZp{0 };
718720 auto quantAttr = buildUnaryOpQuantizationAttr (builder, input, outputType);
719721 if (quantAttr) {
720- // note: negateOp has attributes input1_zp and output_zp
721- result.addAttribute (" input1_zp" ,
722- builder.getI32IntegerAttr (
723- static_cast <int32_t >(quantAttr.getInputZp ())));
724- result.addAttribute (" output_zp" ,
725- builder.getI32IntegerAttr (
726- static_cast <int32_t >(quantAttr.getOutputZp ())));
722+ input1Zp = quantAttr.getInputZp ();
723+ outputZp = quantAttr.getOutputZp ();
724+ }
725+ const std::optional<Value> input1ZpOp =
726+ createZeroPointTensor (builder, loc, input.getType (), input1Zp);
727+ if (!input1ZpOp) {
728+ (void )emitError (
729+ loc, " Failed to create input1 zero point for quantized NEGATE op" );
730+ }
731+
732+ const std::optional<Value> outputZpOp =
733+ createZeroPointTensor (builder, loc, input.getType (), outputZp);
734+ if (!outputZpOp) {
735+ (void )emitError (
736+ loc, " Failed to create output zero point for quantized NEGATE op" );
727737 }
738+
739+ if (input1ZpOp && outputZpOp) {
740+ result.addOperands ({input, input1ZpOp.value (), outputZpOp.value ()});
741+ } else {
742+ // failed to create one or more zero points above: just add input as
743+ // operands. This will trigger error in building the op because of
744+ // missing zero points
745+ result.addOperands ({input});
746+ }
747+
728748 result.types .push_back (outputType);
729749}
730750
@@ -1714,6 +1734,9 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
17141734ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
17151735ZERO_POINT_HELPER(AvgPool2dOp, Input)
17161736ZERO_POINT_HELPER(AvgPool2dOp, Output)
1737+ ZERO_POINT_HELPER(NegateOp, Input1)
1738+ ZERO_POINT_HELPER(NegateOp, Output)
1739+
17171740#undef ZERO_POINT_HELPER
17181741
17191742LogicalResult tosa::TransposeOp::inferReturnTypeComponents (
@@ -2216,7 +2239,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
22162239NARY_SHAPE_INFER(tosa::LogicalXorOp)
22172240NARY_SHAPE_INFER(tosa::MaximumOp)
22182241NARY_SHAPE_INFER(tosa::MinimumOp)
2219- NARY_SHAPE_INFER(tosa::NegateOp)
22202242NARY_SHAPE_INFER(tosa::PowOp)
22212243NARY_SHAPE_INFER(tosa::ReciprocalOp)
22222244NARY_SHAPE_INFER(tosa::ReverseOp)
@@ -2229,6 +2251,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
22292251NARY_SHAPE_INFER(tosa::SigmoidOp)
22302252#undef PRED_SHAPE_INFER
22312253
2254+ LogicalResult tosa::NegateOp::inferReturnTypeComponents (
2255+ MLIRContext *context, ::std::optional<Location> location,
2256+ NegateOp::Adaptor adaptor,
2257+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2258+ ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
2259+ inferredReturnShapes.push_back (ShapedTypeComponents (inputShape));
2260+ return success ();
2261+ }
2262+
2263+ LogicalResult tosa::NegateOp::verify () {
2264+ // Verify same element type
2265+ const Type input1Type = getInput1 ().getType ();
2266+ const Type outputType = getOutput ().getType ();
2267+ if (verifySameElementTypes (*this , input1Type, outputType).failed ())
2268+ return failure ();
2269+
2270+ // Verify same shape
2271+ const SmallVector<Type, 2 > types = {input1Type, outputType};
2272+ if (failed (verifyCompatibleShapes (types)))
2273+ return emitOpError () << " requires the same shape for input1 and output" ;
2274+
2275+ const Type input1EType = getStorageElementTypeOrSelf (getInput1 ().getType ());
2276+ const Type input1ZpEType =
2277+ getStorageElementTypeOrSelf (getInput1Zp ().getType ());
2278+ if (input1EType != input1ZpEType) {
2279+ return emitOpError (" expect both input1 and its zero point are the same "
2280+ " element type, got " )
2281+ << input1EType << " and " << input1ZpEType;
2282+ }
2283+ const Type outputEType = getStorageElementTypeOrSelf (getOutput ().getType ());
2284+ const Type outputZpEType =
2285+ getStorageElementTypeOrSelf (getOutputZp ().getType ());
2286+ if (outputEType != outputZpEType) {
2287+ return emitOpError (" expect both output and its zero point are the same "
2288+ " element type, got " )
2289+ << outputEType << " and " << outputZpEType;
2290+ }
2291+
2292+ FailureOr<int64_t > maybeIZp = getInput1ZeroPoint ();
2293+ if (succeeded (maybeIZp) && verifyInput1ZeroPoint (*maybeIZp).failed ())
2294+ return failure ();
2295+
2296+ FailureOr<int64_t > maybeOZp = getOutputZeroPoint ();
2297+ if (succeeded (maybeOZp) && verifyOutputZeroPoint (*maybeOZp).failed ())
2298+ return failure ();
2299+
2300+ return success ();
2301+ }
2302+
22322303static LogicalResult poolingInferReturnTypes (
22332304 ShapeAdaptor inputShape, ArrayRef<int64_t > kernel, ArrayRef<int64_t > stride,
22342305 ArrayRef<int64_t > pad,
0 commit comments