@@ -227,6 +227,41 @@ def annotate_single_in_single_out(
227227 _annotated = True ,
228228 )
229229
def annotate_single_in_share_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate a single-input op whose output reuses the input's qparams.

    The sole input (``node.args[0]``) is tagged with the config's
    input-activation qspec, and the output is annotated with a
    ``SharedQuantizationSpec`` anchored on that same input edge, so both
    ends of the op resolve to one set of quantization parameters.
    """
    source_act = node.args[0]
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={source_act: quantization_config.input_activation},
        output_qspec=SharedQuantizationSpec((source_act, node)),
        _annotated=True,
    )
243+
def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.stack`` so all inputs and the output share qparams.

    The first stacked tensor gets the config's input-activation qspec; every
    remaining input and the output point back at that first edge through a
    ``SharedQuantizationSpec``, keeping quantization parameters identical
    across the whole stack.
    """
    inputs = node.args[0]
    anchor = inputs[0]
    assert isinstance(anchor, Node)

    shared_with_anchor = SharedQuantizationSpec((anchor, node))
    qspec_map = {anchor: quantization_config.input_activation}
    # Repeated occurrences of the same input node keep their first entry.
    for other in inputs[1:]:
        qspec_map.setdefault(other, shared_with_anchor)

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=qspec_map,
        output_qspec=shared_with_anchor,
        _annotated=True,
    )
264+
230265 def annotate_matmul_input1 (node : Node ):
231266 quantization_config_8a8w = get_8a8w_qnn_ptq_config (
232267 act_symmetric = True , act_observer = MinMaxObserver
@@ -247,6 +282,12 @@ def annotate_matmul_input1(node: Node):
247282 ]:
248283 annotate_single_in_single_out (node , quantization_config_8a8w )
249284 node = node .args [0 ]
285+ elif node .target == torch .ops .aten .stack .default :
286+ annotate_stack (node , quantization_config_8a8w )
287+ node = node .args [0 ]
288+ elif node .target == torch .ops .aten .flatten .using_ints :
289+ annotate_single_in_share_out (node , quantization_config_8a8w )
290+ node = node .args [0 ]
250291 elif node .target == torch .ops .aten .cat .default :
251292 annotate_cat (node , quantization_config_8a8w )
252293 # For v, we tag 8a until conv op.
0 commit comments