@@ -227,6 +227,40 @@ def annotate_single_in_single_out(
             _annotated=True,
         )
 
+    def annotate_single_in_share_out(
+        node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_qspec_map[input_act] = quantization_config.input_activation
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=SharedQuantizationSpec((input_act, node)),
+            _annotated=True,
+        )
+
+    def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
+        input_nodes = node.args[0]
+
+        first_input_node = input_nodes[0]
+        input_qspec_map = {}
+        input_qspec_map[first_input_node] = quantization_config.input_activation
+        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
+            (first_input_node, node)
+        )
+
+        for input_node in input_nodes[1:]:
+            if input_node not in input_qspec_map:
+                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=share_qparams_with_input_act0_qspec,
+            _annotated=True,
+        )
+
     def annotate_matmul_input1(node: Node):
         quantization_config_8a8w = get_8a8w_qnn_ptq_config(
             act_symmetric=True, act_observer=MinMaxObserver
@@ -247,6 +281,12 @@ def annotate_matmul_input1(node: Node):
             ]:
                 annotate_single_in_single_out(node, quantization_config_8a8w)
                 node = node.args[0]
+            elif node.target == torch.ops.aten.stack.default:
+                annotate_stack(node, quantization_config_8a8w)
+                node = node.args[0]
+            elif node.target == torch.ops.aten.flatten.using_ints:
+                annotate_single_in_share_out(node, quantization_config_8a8w)
+                node = node.args[0]
             elif node.target == torch.ops.aten.cat.default:
                 annotate_cat(node, quantization_config_8a8w)
                 # For v, we tag 8a until conv op.
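
These two helpers let the annotate_matmul_input1 walk continue through aten.stack and aten.flatten nodes with a single set of activation qparams: the first (or only) input gets the config's input_activation spec, and every other input plus the output reuse it through SharedQuantizationSpec. The sketch below illustrates the resulting annotation on a toy two-input stack node; it is a minimal illustration rather than the repository's code, assuming the public torch.ao.quantization.quantizer types, that QUANT_ANNOTATION_KEY is the standard "quantization_annotation" meta key, and using None as a placeholder for quantization_config.input_activation.

# Minimal sketch (assumptions noted above); the Toy module and node lookup are illustrative only.
import torch
import torch.fx
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)


class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.stack([x, y])


gm = torch.fx.symbolic_trace(Toy())
stack_node = next(n for n in gm.graph.nodes if n.target == torch.stack)
x_node, y_node = stack_node.args[0]

# The first input would carry the real activation qspec; the second input and the
# output share its quantization parameters via the (input, node) edge.
shared = SharedQuantizationSpec((x_node, stack_node))
stack_node.meta["quantization_annotation"] = QuantizationAnnotation(
    input_qspec_map={x_node: None, y_node: shared},  # None stands in for input_activation
    output_qspec=shared,
    _annotated=True,
)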