88
99import serializer .tosa_serializer as ts
1010import torch
11+ from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
12+ get_input_qparams ,
13+ get_output_qparams ,
14+ )
1115from executorch .backends .arm .operators .node_visitor import (
1216 NodeVisitor ,
1317 register_node_visitor ,
1418)
1519from executorch .backends .arm .tosa_mapping import TosaArg
16- from executorch .backends .arm .tosa_quant_utils import (
17- build_rescale_conv_output ,
18- get_quant_arg_downstream ,
19- get_quant_arg_upstream ,
20- )
20+ from executorch .backends .arm .tosa_quant_utils import build_rescale_conv_output
2121from executorch .backends .arm .tosa_utils import build_reshape , tosa_shape
2222
2323from serializer .tosa_serializer import TosaOp
@@ -57,9 +57,6 @@ def define_node(
5757 ) -> None :
5858 input , weight , bias , stride , pad , dilation , _ , _ , group = inputs
5959
60- # Currently only int8 is supported in quantized types.
61- actual_out_type = ts .DType .INT8 if is_quant_node else output .dtype
62-
6360 # Get the attributes of convolution.
6461 attr = ts .TosaSerializerAttribute ()
6562 pad_attr = [val for val in pad .special for _ in (0 , 1 )]
@@ -82,9 +79,11 @@ def define_node(
8279 dilation_attr [1 ],
8380 )
8481
85- input_zp = (
86- get_quant_arg_upstream (node .all_input_nodes [0 ]).zp if is_quant_node else 0
87- )
82+ input_zp = 0
83+ if inputs [0 ].dtype == ts .DType .INT8 :
84+ # int8 input requires quantization information
85+ input_qparams = get_input_qparams (node )
86+ input_zp = input_qparams [0 ].zp
8887
8988 attr .ConvAttribute (
9089 pad = pad_attr ,
@@ -100,16 +99,22 @@ def define_node(
10099 # Create a zero bias tensor if not presented
101100 out_channels = weight .shape [0 ]
102101 bias_name = "bias" + node .name .split ("default" , 1 )[1 ]
102+ bias_type = output .dtype
103+ if output .dtype == ts .DType .INT8 :
104+ # Conv is quantized to int8, but the TOSA operator has
105+ # output type int32, and the bias must be the same type
106+ # as the TOSA output type
107+ bias_type = ts .DType .INT32
103108 bias = tosa_graph .addConst (
104109 [out_channels ],
105- ts . DType . INT32 if is_quant_node else output . dtype ,
110+ bias_type ,
106111 [0 ] * out_channels ,
107112 name = bias_name ,
108113 )
109114
110115 # The output type is int32 when input type is int8.
111116 conv2d_output_name = output .name
112- if is_quant_node :
117+ if output . dtype == ts . DType . INT8 :
113118 conv2d_res = tosa_graph .addIntermediate (
114119 tosa_shape (output .shape , output .dim_order ), ts .DType .INT32
115120 )
@@ -132,7 +137,7 @@ def define_node(
132137
133138 weight_reshaped = tosa_graph .addIntermediate (
134139 weight_post_shape ,
135- ts . DType . INT8 if is_quant_node else weight .dtype ,
140+ weight .dtype ,
136141 )
137142 build_reshape (
138143 tosa_graph , weight .name , weight_post_shape , weight_reshaped .name
@@ -157,20 +162,19 @@ def define_node(
157162
158163 # For quantized convolution, rescale the output value back to the same
159164 # integer value domain of the next op. Otherwise return float32 output.
160- if is_quant_node :
165+ if inputs [ 0 ]. dtype == ts . DType . INT8 :
161166 # Get scale_factor from input, weight, and output.
162- input_scale = get_quant_arg_upstream (node .all_input_nodes [0 ]).scale
163- weight_scale = get_quant_arg_upstream (node .all_input_nodes [1 ]).scale
164- output_qargs = get_quant_arg_downstream (list (node .users )[0 ])
165-
167+ input_scale = input_qparams [0 ].scale
168+ weight_scale = input_qparams [1 ].scale
169+ output_qargs = get_output_qparams (node )
166170 build_rescale_conv_output (
167171 tosa_graph ,
168172 # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
169173 conv2d_res ,
170174 output .name ,
171- actual_out_type ,
175+ output . dtype ,
172176 input_scale ,
173177 weight_scale ,
174- output_qargs .scale ,
175- output_qargs .zp ,
178+ output_qargs [ 0 ] .scale ,
179+ output_qargs [ 0 ] .zp ,
176180 )
0 commit comments