@@ -2226,7 +2226,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
22262226NARY_SHAPE_INFER(tosa::NegateOp)
22272227NARY_SHAPE_INFER(tosa::PowOp)
22282228NARY_SHAPE_INFER(tosa::ReciprocalOp)
2229- NARY_SHAPE_INFER(tosa::RescaleOp)
22302229NARY_SHAPE_INFER(tosa::ReverseOp)
22312230NARY_SHAPE_INFER(tosa::RsqrtOp)
22322231NARY_SHAPE_INFER(tosa::SinOp)
@@ -2676,6 +2675,147 @@ LogicalResult TransposeConv2DOp::verify() {
26762675 return success ();
26772676}
26782677
2678+ LogicalResult RescaleOp::verify () {
2679+ auto inputType = llvm::dyn_cast<ShapedType>(getInput ().getType ());
2680+ if (!inputType) {
2681+ emitOpError (" expect shaped tensor for input, got " ) << getInput ().getType ();
2682+ return failure ();
2683+ }
2684+
2685+ auto inputElementType =
2686+ getStorageElementTypeOrSelf (inputType.getElementType ());
2687+ if (!mlir::isa<IntegerType>(inputElementType)) {
2688+ emitOpError (" expect input to have integer element type, got " )
2689+ << inputElementType;
2690+ return failure ();
2691+ }
2692+
2693+ auto outputType = llvm::dyn_cast<ShapedType>(getOutput ().getType ());
2694+ if (!outputType) {
2695+ emitOpError (" expect shaped tensor for output, got " )
2696+ << getOutput ().getType ();
2697+ return failure ();
2698+ }
2699+
2700+ auto outputElementType =
2701+ getStorageElementTypeOrSelf (outputType.getElementType ());
2702+ if (!mlir::isa<IntegerType>(outputElementType)) {
2703+ emitOpError (" expect output to have integer element type, got " )
2704+ << outputElementType;
2705+ return failure ();
2706+ }
2707+
2708+ auto input_zp = getInputZpAttr ().getInt ();
2709+ if (input_zp != 0 ) {
2710+ // only int8/uint8 and uint16 input can have non-zero input_zp
2711+ if (!inputElementType.isInteger (8 ) &&
2712+ !(inputElementType.isInteger (16 ) && getInputUnsigned ())) {
2713+ emitOpError (" expect input_zp of 0, got " ) << input_zp;
2714+ return failure ();
2715+ }
2716+ // input_zp must be either 0 or 32768 for uint16 input
2717+ if (inputElementType.isInteger (16 ) && getInputUnsigned () &&
2718+ input_zp != 32768 ) {
2719+ emitOpError (
2720+ " expect input_zp of 0 or 32768 for unsigned int16 input, got " )
2721+ << input_zp;
2722+ return failure ();
2723+ }
2724+ }
2725+
2726+ auto output_zp = getOutputZpAttr ().getInt ();
2727+ if (output_zp != 0 ) {
2728+ // only int8/uint8 and uint16 output can have non-zero output_zp
2729+ if (!outputElementType.isInteger (8 ) &&
2730+ !(outputElementType.isInteger (16 ) && getOutputUnsigned ())) {
2731+ emitOpError (" expect output_zp of 0, got " ) << output_zp;
2732+ return failure ();
2733+ }
2734+ // output_zp must be either 0 or 32768 for uint16 output
2735+ if (outputElementType.isInteger (16 ) && getOutputUnsigned () &&
2736+ output_zp != 32768 ) {
2737+ emitOpError (
2738+ " expect output_zp of 0 or 32768 for unsigned int16 output, got " )
2739+ << output_zp;
2740+ return failure ();
2741+ }
2742+ }
2743+
2744+ auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier ().getType ());
2745+ if (!multiplierType) {
2746+ emitOpError (" expect shaped tensor for multiplier, got " )
2747+ << getMultiplier ().getType ();
2748+ return failure ();
2749+ }
2750+
2751+ auto shiftType = llvm::dyn_cast<ShapedType>(getShift ().getType ());
2752+ if (!shiftType) {
2753+ emitOpError (" expect shaped tensor for shift, got " ) << getShift ().getType ();
2754+ return failure ();
2755+ }
2756+
2757+ // multiplier element type must be i32 for scale32 = true
2758+ if (getScale32 () && !multiplierType.getElementType ().isInteger (32 )) {
2759+ emitOpError (" expect i32 element type for multiplier for scale32=true, got " )
2760+ << multiplierType.getElementType ();
2761+ return failure ();
2762+ }
2763+
2764+ // multiplier element type must be i16 for scale32 = false
2765+ if (!getScale32 () && !multiplierType.getElementType ().isInteger (16 )) {
2766+ emitOpError (
2767+ " expect i16 element type for multiplier for scale32=false, got " )
2768+ << multiplierType.getElementType ();
2769+ return failure ();
2770+ }
2771+
2772+ if (!inputType.hasRank ())
2773+ return success ();
2774+
2775+ // multiplier/shift must have shape = {numChannels},
2776+ // where numChannel is 1 if per_channel = false
2777+ // otherwise numChannel is dimension in input shape's last axis
2778+ int64_t numChannels = 1 ;
2779+ if (getPerChannel ()) {
2780+ numChannels = inputType.getDimSize (inputType.getRank () - 1 );
2781+ }
2782+
2783+ if (!multiplierType.hasRank ())
2784+ return success ();
2785+
2786+ ArrayRef<int64_t > multiplierShape = multiplierType.getShape ();
2787+ // multiplier input has rank 1 by dialect definition
2788+ if (multiplierShape[0 ] != ShapedType::kDynamic &&
2789+ multiplierShape[0 ] != numChannels) {
2790+ emitOpError (" expect shape of { " )
2791+ << numChannels << " } for multiplier input, got { "
2792+ << multiplierShape[0 ] << " }" ;
2793+ return failure ();
2794+ }
2795+
2796+ if (!shiftType.hasRank ())
2797+ return success ();
2798+
2799+ ArrayRef<int64_t > shiftShape = shiftType.getShape ();
2800+ // shift input has rank 1 by dialect definition
2801+ if (shiftShape[0 ] != ShapedType::kDynamic && shiftShape[0 ] != numChannels) {
2802+ emitOpError (" expect shape of { " )
2803+ << numChannels << " } for shift input, got { " << shiftShape[0 ] << " }" ;
2804+ return failure ();
2805+ }
2806+
2807+ return success ();
2808+ }
2809+
2810+ LogicalResult RescaleOp::inferReturnTypeComponents (
2811+ MLIRContext *context, ::std::optional<Location> location,
2812+ RescaleOp::Adaptor adaptor,
2813+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2814+ ShapeAdaptor inputShape (adaptor.getInput ().getType ());
2815+ inferredReturnShapes.push_back (ShapedTypeComponents (inputShape));
2816+ return success ();
2817+ }
2818+
26792819LogicalResult IfOp::inferReturnTypeComponents (
26802820 MLIRContext *context, ::std::optional<Location> location,
26812821 IfOp::Adaptor adaptor,
0 commit comments