@@ -34,14 +34,15 @@ Value buildRescaleMultiplier(bool scale32, PatternRewriter &rewriter,
3434// rounding mode
3535Value buildRescale (PatternRewriter &rewriter, Operation *op,
3636 ShapedType output_type, Value input_val, double scale,
37- int64_t input_zp, int64_t output_zp, bool double_round ,
37+ int64_t input_zp, int64_t output_zp, StringRef rounding_mode ,
3838 bool scale32) {
3939 int32_t multiplier;
4040 int32_t shift;
4141
4242 int32_t scale_width = scale32 ? 32 : 16 ;
4343
44- computeMultiplierAndShift (scale, multiplier, shift, scale_width);
44+ if (!computeMultiplierAndShift (scale, multiplier, shift, scale_width))
45+ op->emitError (" buildRescale: shift must be in the range 2 <= shift <= 62" );
4546
4647 Value multiplier_val =
4748 buildRescaleMultiplier (scale32, rewriter, op, {multiplier});
@@ -52,11 +53,23 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
5253 bool input_unsigned = input_val.getType ().isUnsignedInteger ();
5354 bool output_unsigned = output_type.isUnsignedInteger ();
5455
56+ // Create input_zp matches the input type and output_zp matches the output
57+ // type of RescaleOp
58+ const auto input_zp_val = tosa::createZeroPointTensor (
59+ rewriter, op->getLoc (), dyn_cast<TensorType>(input_val.getType ()),
60+ input_zp);
61+ if (!input_zp_val.has_value ())
62+ op->emitError (" Failed to create input zero-point tensor for RescaleOp." );
63+
64+ const auto output_zp_val = tosa::createZeroPointTensor (
65+ rewriter, op->getLoc (), output_type, output_zp);
66+ if (!output_zp_val.has_value ())
67+ op->emitError (" Failed to create output zero-point tensor for RescaleOp." );
68+
5569 auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
5670 rewriter, op->getLoc (), output_type, input_val, multiplier_val, shift_val,
57- rewriter.getI32IntegerAttr (static_cast <int32_t >(input_zp)),
58- rewriter.getI32IntegerAttr (static_cast <int32_t >(output_zp)),
59- rewriter.getBoolAttr (scale32), rewriter.getBoolAttr (double_round),
71+ input_zp_val.value (), output_zp_val.value (),
72+ rewriter.getBoolAttr (scale32), rewriter.getStringAttr (rounding_mode),
6073 rewriter.getBoolAttr (false ), rewriter.getBoolAttr (input_unsigned),
6174 rewriter.getBoolAttr (output_unsigned));
6275
@@ -73,7 +86,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
7386 auto output_type = input_type.clone (rewriter.getI32Type ());
7487
7588 return buildRescale (rewriter, op, output_type, input_val, input_scale,
76- input_zp, 0 , false , true );
89+ input_zp, 0 , " SINGLE_ROUND " , true );
7790}
7891
7992// Creates a TOSA rescale op based on conv2d parameters.
@@ -96,6 +109,16 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
96109 bool input_unsigned = input_qtype.isUnsignedInteger ();
97110 bool output_unsigned = output_qtype.isUnsignedInteger ();
98111
112+ const auto input_zp_val = tosa::createZeroPointTensor (
113+ rewriter, op->getLoc (), input_type, static_cast <int64_t >(0 ));
114+ if (!input_zp_val.has_value ())
115+ op->emitError (" Failed to create input zero-point tensor for RescaleOp." );
116+
117+ const auto output_zp_val = tosa::createZeroPointTensor (
118+ rewriter, op->getLoc (), output_type, output_zp);
119+ if (!output_zp_val.has_value ())
120+ op->emitError (" Failed to create output zero-point tensor for RescaleOp." );
121+
99122 if (auto weight_per_tensor_qtype =
100123 dyn_cast<mlir::quant::UniformQuantizedType>(
101124 weight_type.getElementType ())) {
@@ -107,7 +130,11 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
107130
108131 double op_tensor_scale = (input_scale * weight_scale) / output_scale;
109132
110- computeMultiplierAndShift (op_tensor_scale, multiplier, shift, scale_width);
133+ if (!computeMultiplierAndShift (op_tensor_scale, multiplier, shift,
134+ scale_width))
135+ op->emitError (
136+ " buildRescaleOpConvOutput: shift must be in the range 2 <= shift <= "
137+ " 62" );
111138
112139 Value multiplier_val =
113140 buildRescaleMultiplier (scale32, rewriter, op, {multiplier});
@@ -117,10 +144,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
117144
118145 auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
119146 rewriter, op->getLoc (), output_type, conv_val, multiplier_val,
120- shift_val, rewriter.getI32IntegerAttr (0 ),
121- rewriter.getI32IntegerAttr (output_zp), rewriter.getBoolAttr (scale32),
122- rewriter.getBoolAttr (true ), rewriter.getBoolAttr (false ),
123- rewriter.getBoolAttr (input_unsigned),
147+ shift_val, input_zp_val.value (), output_zp_val.value (),
148+ rewriter.getBoolAttr (scale32), rewriter.getStringAttr (" DOUBLE_ROUND" ),
149+ rewriter.getBoolAttr (false ), rewriter.getBoolAttr (input_unsigned),
124150 rewriter.getBoolAttr (output_unsigned));
125151
126152 return rescale_op.getResult ();
@@ -136,17 +162,16 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
136162 weight_per_channel_qtype.getScales ().begin (),
137163 weight_per_channel_qtype.getScales ().end ());
138164
139- int64_t output_zp = output_qtype.getZeroPoint ();
140- double output_scale = output_qtype.getScale ();
141-
142165 for (double weight_scale : weight_scale_arr) {
143166 int32_t multiplier;
144167 int32_t shift;
145168
146169 double op_channel_scale = (input_scale * weight_scale) / output_scale;
147170
148- computeMultiplierAndShift (op_channel_scale, multiplier, shift,
149- scale_width);
171+ if (!computeMultiplierAndShift (op_channel_scale, multiplier, shift, 32 ))
172+ op->emitError (
173+ " buildRescaleOpConvOutput: shift must be in the range 2 <= shift "
174+ " <= 62" );
150175
151176 multiplier_arr.push_back (multiplier);
152177 shift_arr.push_back (static_cast <int8_t >(shift));
@@ -161,10 +186,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
161186
162187 auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
163188 rewriter, op->getLoc (), output_type, conv_val, multiplier_val,
164- shift_val, rewriter.getI32IntegerAttr (0 ),
165- rewriter.getI32IntegerAttr (output_zp), rewriter.getBoolAttr (scale32),
166- rewriter.getBoolAttr (true ), rewriter.getBoolAttr (true ),
167- rewriter.getBoolAttr (input_unsigned),
189+ shift_val, input_zp_val.value (), output_zp_val.value (),
190+ rewriter.getBoolAttr (scale32), rewriter.getStringAttr (" DOUBLE_ROUND" ),
191+ rewriter.getBoolAttr (true ), rewriter.getBoolAttr (input_unsigned),
168192 rewriter.getBoolAttr (output_unsigned));
169193
170194 return rescale_op.getResult ();
0 commit comments