@@ -698,23 +698,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
698698 result.types .push_back (outputType);
699699}
700700
701- // / This builder is called on single-parameter unary operators that have scale
702- // / relationship between their input and output, expressed by the
703- // / UnaryOpQuantizationAttr.
704- static void buildUnaryOpWithQuantInfo (OpBuilder &builder,
705- OperationState &result, Type outputType,
706- Value input) {
707- result.addOperands (input);
701+ // / This builder is called on single-parameter negate operator
702+ // / to construct input and output zero points based on their
703+ // / types.
704+ static void buildNegateOpWithQuantInfo (OpBuilder &builder,
705+ OperationState &result, Type outputType,
706+ Value input) {
707+ const Location loc{result.location };
708+ int64_t input1Zp{0 };
709+ int64_t outputZp{0 };
708710 auto quantAttr = buildUnaryOpQuantizationAttr (builder, input, outputType);
709711 if (quantAttr) {
710- // note: negateOp has attributes input1_zp and output_zp
711- result.addAttribute (" input1_zp" ,
712- builder.getI32IntegerAttr (
713- static_cast <int32_t >(quantAttr.getInputZp ())));
714- result.addAttribute (" output_zp" ,
715- builder.getI32IntegerAttr (
716- static_cast <int32_t >(quantAttr.getOutputZp ())));
712+ input1Zp = quantAttr.getInputZp ();
713+ outputZp = quantAttr.getOutputZp ();
714+ }
715+ const std::optional<Value> input1ZpOp =
716+ createZeroPointTensor (builder, loc, input.getType (), input1Zp);
717+ if (!input1ZpOp) {
718+ (void )emitError (
719+ loc, " Failed to create input1 zero point for quantized NEGATE op" );
720+ }
721+
722+ const std::optional<Value> outputZpOp =
723+ createZeroPointTensor (builder, loc, input.getType (), outputZp);
724+ if (!outputZpOp) {
725+ (void )emitError (
726+ loc, " Failed to create output zero point for quantized NEGATE op" );
717727 }
728+
729+ if (input1ZpOp && outputZpOp) {
730+ result.addOperands ({input, input1ZpOp.value (), outputZpOp.value ()});
731+ } else {
732+ // failed to create one or more zero points above: just add input as
733+ // operands. This will trigger error in building the op because of
734+ // missing zero points
735+ result.addOperands ({input});
736+ }
737+
718738 result.types .push_back (outputType);
719739}
720740
@@ -1562,6 +1582,9 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
15621582ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
15631583ZERO_POINT_HELPER(AvgPool2dOp, Input)
15641584ZERO_POINT_HELPER(AvgPool2dOp, Output)
1585+ ZERO_POINT_HELPER(NegateOp, Input1)
1586+ ZERO_POINT_HELPER(NegateOp, Output)
1587+
15651588#undef ZERO_POINT_HELPER
15661589
15671590LogicalResult tosa::TransposeOp::inferReturnTypeComponents (
@@ -2041,7 +2064,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
20412064NARY_SHAPE_INFER(tosa::LogicalXorOp)
20422065NARY_SHAPE_INFER(tosa::MaximumOp)
20432066NARY_SHAPE_INFER(tosa::MinimumOp)
2044- NARY_SHAPE_INFER(tosa::NegateOp)
20452067NARY_SHAPE_INFER(tosa::PowOp)
20462068NARY_SHAPE_INFER(tosa::ReciprocalOp)
20472069NARY_SHAPE_INFER(tosa::RescaleOp)
@@ -2055,6 +2077,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
20552077NARY_SHAPE_INFER(tosa::SigmoidOp)
20562078#undef PRED_SHAPE_INFER
20572079
2080+ LogicalResult tosa::NegateOp::inferReturnTypeComponents (
2081+ MLIRContext *context, ::std::optional<Location> location,
2082+ NegateOp::Adaptor adaptor,
2083+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2084+ ShapeAdaptor inputShape (adaptor.getInput1 ().getType ());
2085+ inferredReturnShapes.push_back (ShapedTypeComponents (inputShape));
2086+ return success ();
2087+ }
2088+
2089+ LogicalResult tosa::NegateOp::verify () {
2090+ // Verify same element type
2091+ const Type input1Type = getInput1 ().getType ();
2092+ const Type outputType = getOutput ().getType ();
2093+ if (verifySameElementTypes (*this , input1Type, outputType).failed ())
2094+ return failure ();
2095+
2096+ // Verify same shape
2097+ const SmallVector<Type, 2 > types = {input1Type, outputType};
2098+ if (failed (verifyCompatibleShapes (types)))
2099+ return emitOpError () << " requires the same shape for input1 and output" ;
2100+
2101+ const Type input1EType = getStorageElementTypeOrSelf (getInput1 ().getType ());
2102+ const Type input1ZpEType =
2103+ getStorageElementTypeOrSelf (getInput1Zp ().getType ());
2104+ if (input1EType != input1ZpEType) {
2105+ return emitOpError (" expect both input1 and its zero point are the same "
2106+ " element type, got " )
2107+ << input1EType << " and " << input1ZpEType;
2108+ }
2109+ const Type outputEType = getStorageElementTypeOrSelf (getOutput ().getType ());
2110+ const Type outputZpEType =
2111+ getStorageElementTypeOrSelf (getOutputZp ().getType ());
2112+ if (outputEType != outputZpEType) {
2113+ return emitOpError (" expect both output and its zero point are the same "
2114+ " element type, got " )
2115+ << outputEType << " and " << outputZpEType;
2116+ }
2117+
2118+ FailureOr<int64_t > maybeIZp = getInput1ZeroPoint ();
2119+ if (succeeded (maybeIZp) && verifyInput1ZeroPoint (*maybeIZp).failed ())
2120+ return failure ();
2121+
2122+ FailureOr<int64_t > maybeOZp = getOutputZeroPoint ();
2123+ if (succeeded (maybeOZp) && verifyOutputZeroPoint (*maybeOZp).failed ())
2124+ return failure ();
2125+
2126+ return success ();
2127+ }
2128+
20582129static LogicalResult poolingInferReturnTypes (
20592130 ShapeAdaptor inputShape, ArrayRef<int64_t > kernel, ArrayRef<int64_t > stride,
20602131 ArrayRef<int64_t > pad,
0 commit comments