@@ -49,12 +49,8 @@ def call_operator(self, op, args, kwargs, meta):
         )

         # convolution with bias and activation is int16
-        # The bias is assumed to be quantized with the same quantization parameters as
-        # as the output of the convolution
         bias = args[2]
-        assert (
-            meta.data["output_qparams"][0].dtype == bias.data.dtype
-        ), "Bias needs to have same type as quantized output type"
+
         no_bias_args = list(args)
         no_bias_args[2] = None
         # split up to convolution + bias
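Note on the dropped assertion: the bias is no longer required to share the quantized output's dtype and quantization parameters; it is split off here and folded back in after the convolution output has been rescaled to int32 (second hunk below). A minimal sketch of the relaxed precondition, using illustrative dtypes that are assumptions rather than values taken from the pass:

```python
import torch

# Hedged illustration only: with the reworked pass the bias may keep its own
# quantization (e.g. int32, possibly per-channel) while the convolution
# output is quantized to int16, so the old dtype-equality assert is dropped.
quantized_output_dtype = torch.int16          # dtype from output_qparams
bias = torch.zeros(8, dtype=torch.int32)      # bias kept at higher precision

old_check_passes = quantized_output_dtype == bias.dtype
print(old_check_passes)  # False: previously rejected, now handled by the split
```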
@@ -79,46 +75,30 @@ def call_operator(self, op, args, kwargs, meta):
         # The conv will get the output int48 scaled to int32 in serialization step.
         # To be able to add the bias we need to first scale (cast?) the output to int32.
         # The resulting i32 sum will then need to be scaled back to the output dtype.
-
-        # calculate common rescale factor from convolution output and bias quantization
         output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
         conv_output_scale = output_qparams.scale
-        bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
-        bias_scale = bias_qparams.scale

-        common_scale = max(bias_scale, conv_output_scale)
-
-        # calculate how we can rescale bias and conv to a common scale and maximize the output range
-        bias_rescale_factor = bias_scale / common_scale
-        conv_rescale_factor = conv_output_scale / common_scale
+        bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
+        per_channel_quant = bias_qparams.per_channel

-        # Either of conv output or bias now covers the full int16 range and the other one a smaller range.
-        # Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range.
-        # Worst case here is that both bias and conv output covers the full int16 range so we leave one bit
-        # and then one for the sign bit.
-        bits_left_to_shift = 14
+        if per_channel_quant:
+            bias_scale = bias_qparams.get_scale_per_channel()
+        else:
+            bias_scale = [bias_qparams.get_scale_per_tensor()]

-        # update rescale factors
-        bias_rescale_factor *= 1 << bits_left_to_shift
-        conv_rescale_factor *= 1 << bits_left_to_shift
+        conv_rescale_factors = [1.0] * len(bias_scale)
+        final_output_scale = [b / conv_output_scale for b in bias_scale]

         conv_output = super().call_operator(
             exir_ops.backend.tosa.RESCALE.default,
-            (convolution, torch.int32, [conv_rescale_factor], 0, 0),
-            {},
-            new_meta,
-        )
-
-        bias_rescaled = super().call_operator(
-            exir_ops.backend.tosa.RESCALE.default,
-            (channel_bias, torch.int32, [bias_rescale_factor], 0, 0),
+            (convolution, torch.int32, conv_rescale_factors, 0, 0),
             {},
             new_meta,
         )

         add = super().call_operator(
             exir_ops.edge.aten.add.Tensor,
-            (conv_output, bias_rescaled),
+            (conv_output, channel_bias),
             {},
             new_meta,
         )
@@ -128,7 +108,7 @@ def call_operator(self, op, args, kwargs, meta):
             (
                 add,
                 output_dtype,
-                [(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))],
+                final_output_scale,
                 0,
                 0,
             ),
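For reference, a minimal sketch of the rescale bookkeeping the new code performs, assuming the common convention that the bias is quantized with the accumulator scale (input_scale * weight_scale, possibly per channel): the convolution output is brought to int32 with no scaling, the channel bias is added in int32, and the sum is rescaled into the output quantization with bias_scale / conv_output_scale. Names below are illustrative and not taken from the pass.

```python
from typing import List, Tuple


def rescale_factors(
    bias_scales: List[float], conv_output_scale: float
) -> Tuple[List[float], List[float]]:
    # Conv output (conceptually int48) goes to int32 with no scaling applied.
    conv_rescale_factors = [1.0] * len(bias_scales)
    # After adding the int32 bias, scale the sum into the output quantization.
    final_output_scale = [s / conv_output_scale for s in bias_scales]
    return conv_rescale_factors, final_output_scale


# Example with two per-channel bias scales and a single conv output scale.
conv_factors, out_scales = rescale_factors([0.02, 0.04], 0.01)
print(conv_factors)  # [1.0, 1.0]
print(out_scales)    # [2.0, 4.0], used in the final RESCALE back to the output dtype
```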