44# LICENSE file in the root directory of this source tree.
55
66# pyre-unsafe
7-
87from typing import Callable , List , Optional
98
109import torch
2423 # DATA LAYOUT OPS
2524 torch .ops .aten .squeeze .default ,
2625 torch .ops .aten .squeeze_copy .default ,
26+ torch .ops .aten .squeeze_copy .dim ,
27+ torch .ops .aten .squeeze .dim ,
28+ torch .ops .aten .squeeze .dims ,
2729 torch .ops .aten .unsqueeze .default ,
2830 torch .ops .aten .unsqueeze_copy .default ,
2931 torch .ops .aten .reshape .default ,
3335 # torch.ops.aten.view_as_complex_copy.default,
3436 # torch.ops.aten.view_as_real.default,
3537 # torch.ops.aten.view_as_real_copy.default,
38+ torch .ops .aten .view .default ,
3639 torch .ops .aten .view_copy .default ,
3740 torch .ops .aten .select .int ,
3841 torch .ops .aten .select_copy .int ,
3942 torch .ops .aten .slice .Tensor ,
4043 torch .ops .aten .slice_copy .Tensor ,
41- # 'concat' should be handled separately as it has a sequence of inputs and
42- # makes the implementation unnecessary complicated.
43- # torch.ops.aten.concat.default,
44+ torch .ops .aten .split .Tensor ,
45+ torch .ops .aten .split_with_sizes .default ,
4446 torch .ops .aten .transpose .Dimname ,
4547 torch .ops .aten .transpose .int ,
4648 torch .ops .aten .transpose_copy .int ,
4749 torch .ops .aten .tile .default ,
4850 torch .ops .aten .flip .default ,
51+ torch .ops .aten .cat .default ,
52+ torch .ops .aten .stack .default ,
4953]
5054
5155
@@ -66,15 +70,31 @@ def _annotate_generic(
6670 if arm_quantizer_utils .is_annotated (node ):
6771 continue
6872
69- input_node = node .args [0 ]
73+ input_acts = node .args [0 ]
74+
75+ # Check to see if there are multiple inputs.
76+ # this allows for stack/cat ops to be annotated
77+ # in a similar way.
78+ has_multi_inputs = isinstance (input_acts , list )
79+
80+ input_act0 = input_acts [0 ] if has_multi_inputs else input_acts
7081
7182 # Using a non-shared quantization spec here as a SharedQuantizationSpec
7283 # can lead to a recursion.
7384 _annotate_input_qspec_map (
74- node , input_node , quantization_config .get_input_act_qspec ()
85+ node , input_act0 , quantization_config .get_input_act_qspec ()
7586 )
76- _annotate_output_qspec (node , SharedQuantizationSpec ((input_node , node )))
87+ shared_with_input0_qspec = SharedQuantizationSpec ((input_act0 , node ))
88+
89+ if has_multi_inputs :
90+ # For the rest of the inputs, share qspec with first.
91+ for input_act in input_acts [1 :]:
92+ if input_act is not input_act0 :
93+ node .meta ["quantization_annotation" ].input_qspec_map [
94+ input_act
95+ ] = shared_with_input0_qspec
7796
97+ _annotate_output_qspec (node , shared_with_input0_qspec )
7898 arm_quantizer_utils .mark_nodes_as_annotated ([node ])
7999 annotated_partitions .append ([node ])
80100
0 commit comments