 from torch._ops import OpOverload

 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
+from torch.fx import Node

-from torch.ao.quantization.observer import FixedQParamsObserver
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e import FixedQParamsFakeQuantize, FixedQParamsObserver
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     DerivedQuantizationSpec,
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-from torch.fx import Node

 from .qconfig import (
     get_16a16w_qnn_ptq_config,
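The hunk above moves the PT2E quantization imports from torch.ao.quantization to torchao.quantization.pt2e; the fixed-qparams classes now come from torchao, and the annotation helpers lose their private underscore prefix (_annotate_input_qspec_map becomes annotate_input_qspec_map, _annotate_output_qspec becomes annotate_output_qspec). A minimal sketch of the migrated helper API, with an illustrative qspec standing in for this backend's real configs from .qconfig (the observer choice and quant ranges below are assumptions, not values taken from this commit):

import torch
from torch.ao.quantization.observer import MinMaxObserver
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
    QuantizationSpec,
    annotate_input_qspec_map,
    annotate_output_qspec,
)

# Illustrative 8-bit per-tensor activation qspec; a real backend would
# supply its own observer/fake-quantize constructor and ranges.
act_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

def annotate_unary(node: Node, act_node: Node) -> None:
    # The pattern every annotator in this file follows: attach a qspec
    # to each float input edge, then one to the node's output.
    annotate_input_qspec_map(node, act_node, act_qspec)
    annotate_output_qspec(node, act_qspec)

The annotators changed below follow exactly this pattern, with configs drawn from QuantizationConfig.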
@@ -618,19 +615,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
         return

     # TODO current only support 16a16w
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.input_activation,
     )
     nodes_to_mark_annotated = [node]
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)

@@ -819,25 +816,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
     )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)

@@ -1002,12 +999,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
@@ -1018,9 +1015,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
             bias_config = quantization_config.bias(node)
         else:
             bias_config = quantization_config.bias
-        _annotate_input_qspec_map(node, bias_node, bias_config)
+        annotate_input_qspec_map(node, bias_node, bias_config)
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)

     # We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack.
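In annotate_linear above, quantization_config.bias may be a callable that builds a per-node spec before the bias edge is annotated. A plausible shape for such a callable, sketched with the DerivedQuantizationSpec imported at the top of this file; the helper name and the derivation rule scale_bias = scale_act * scale_weight are common PTQ conventions assumed for illustration, not code from this commit:

import torch
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec

def derive_bias_qspec(node: Node) -> DerivedQuantizationSpec:
    # Hypothetical helper: ties the bias qparams of a linear node to its
    # activation and weight observers via scale = s_act * s_weight.
    act, weight = node.args[0], node.args[1]

    def _derive_qparams(observers):
        act_obs, weight_obs = observers
        act_scale, _ = act_obs.calculate_qparams()
        weight_scale, _ = weight_obs.calculate_qparams()
        scale = (act_scale * weight_scale).reshape(-1)
        zero_point = torch.zeros_like(scale, dtype=torch.int32)
        return scale, zero_point

    return DerivedQuantizationSpec(
        derived_from=[(act, node), (weight, node)],
        derive_qparams_fn=_derive_qparams,
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
    )

With quantization_config.bias set to such a callable, the callable branch above resolves bias_config = quantization_config.bias(node) per node before annotating the bias input.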
@@ -1038,29 +1035,29 @@ def annotate_batch_and_instance_norm(
         return

     annotated_args = [act]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act,
         quantization_config.input_activation,
     )
     # QNN requires uint8 instead of int8 in 'weight' config
     if weight is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight,
             quantization_config.input_activation,
         )
         annotated_args.append(weight)

     if bias is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias,
             quantization_config.bias,
         )
         annotated_args.append(bias)

-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated([node, *annotated_args])

@@ -1070,7 +1067,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None:
         return

     if _is_float_tensor(node):
-        _annotate_output_qspec(node, quantization_config.output_activation)
+        annotate_output_qspec(node, quantization_config.output_activation)
         _mark_nodes_as_annotated([node])

@@ -1086,32 +1083,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
         return
     input_act_qspec = quantization_config.input_activation

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         input_act_qspec,
     )
     if input_act_qspec.dtype == torch.int32:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             get_16a16w_qnn_ptq_config().weight,
         )
     else:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             input_act_qspec,
         )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
