 from torch._ops import OpOverload

 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
+from torch.fx import Node

-from torch.ao.quantization.observer import FixedQParamsObserver
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e import FixedQParamsFakeQuantize, FixedQParamsObserver
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     DerivedQuantizationSpec,
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-from torch.fx import Node

 from .qconfig import (
     get_16a16w_qnn_ptq_config,
@@ -643,19 +640,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
         return

     # TODO current only support 16a16w
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.input_activation,
     )
     nodes_to_mark_annotated = [node]
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)


@@ -844,25 +841,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
     )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)


@@ -1027,12 +1024,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
@@ -1043,9 +1040,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
             bias_config = quantization_config.bias(node)
         else:
             bias_config = quantization_config.bias
-        _annotate_input_qspec_map(node, bias_node, bias_config)
+        annotate_input_qspec_map(node, bias_node, bias_config)
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)

     # We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack.
@@ -1063,29 +1060,29 @@ def annotate_batch_and_instance_norm(
         return

     annotated_args = [act]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act,
         quantization_config.input_activation,
     )
     # QNN requires uint8 instead of int8 in 'weight' config
     if weight is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight,
             quantization_config.input_activation,
         )
         annotated_args.append(weight)

     if bias is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias,
             quantization_config.bias,
         )
         annotated_args.append(bias)

-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated([node, *annotated_args])


@@ -1095,7 +1092,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None:
         return

     if _is_float_tensor(node):
-        _annotate_output_qspec(node, quantization_config.output_activation)
+        annotate_output_qspec(node, quantization_config.output_activation)
         _mark_nodes_as_annotated([node])


@@ -1111,32 +1108,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
         return
     input_act_qspec = quantization_config.input_activation

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         input_act_qspec,
     )
     if input_act_qspec.dtype == torch.int32:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             get_16a16w_qnn_ptq_config().weight,
         )
     else:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             input_act_qspec,
         )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)


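Net effect of the hunks above: the QNN quantizer now pulls FixedQParamsFakeQuantize, FixedQParamsObserver, and the qspec annotation helpers from torchao's public pt2e API instead of the private torch.ao.quantization modules, and every call site drops the leading underscore (annotate_input_qspec_map, annotate_output_qspec); the annotation logic itself is unchanged. Below is a minimal sketch of an annotator written against the new imports; annotate_example and its untyped config argument are illustrative, not part of this commit.

from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
    annotate_input_qspec_map,
    annotate_output_qspec,
)


def annotate_example(node: Node, quantization_config) -> None:
    # Hypothetical annotator mirroring the pattern used throughout this diff:
    # map each input node to its quantization spec, then set the output spec.
    act_node = node.args[0]
    annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
    annotate_output_qspec(node, quantization_config.output_activation)
    # The real annotators also guard with _is_annotated([node]) up front and
    # finish with _mark_nodes_as_annotated(...), both module-local helpers.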