@@ -12,20 +12,17 @@
 from torch._ops import OpOverload
 
 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
+from torch.fx import Node
 
-from torch.ao.quantization.observer import FixedQParamsObserver
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e import FixedQParamsFakeQuantize, FixedQParamsObserver
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     DerivedQuantizationSpec,
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-from torch.fx import Node
 
 from .qconfig import (
     get_16a16w_qnn_ptq_config,
@@ -643,19 +640,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None
         return
 
     # TODO current only support 16a16w
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
 
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.input_activation,
     )
     nodes_to_mark_annotated = [node]
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
 
 
@@ -844,25 +841,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None
     if _is_annotated([node]):
         return
 
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
     )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
 
 
@@ -1027,12 +1024,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
     if _is_annotated([node]):
         return
 
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
@@ -1043,9 +1040,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
             bias_config = quantization_config.bias(node)
         else:
             bias_config = quantization_config.bias
-        _annotate_input_qspec_map(node, bias_node, bias_config)
+        annotate_input_qspec_map(node, bias_node, bias_config)
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
 
 # We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack.
@@ -1063,29 +1060,29 @@ def annotate_batch_and_instance_norm(
         return
 
     annotated_args = [act]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act,
         quantization_config.input_activation,
     )
     # QNN requires uint8 instead of int8 in 'weight' config
     if weight is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight,
             quantization_config.input_activation,
         )
         annotated_args.append(weight)
 
     if bias is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias,
             quantization_config.bias,
         )
         annotated_args.append(bias)
 
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated([node, *annotated_args])
 
 
@@ -1095,7 +1092,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None
         return
 
     if _is_float_tensor(node):
-        _annotate_output_qspec(node, quantization_config.output_activation)
+        annotate_output_qspec(node, quantization_config.output_activation)
         _mark_nodes_as_annotated([node])
 
 
@@ -1111,32 +1108,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None
         return
     input_act_qspec = quantization_config.input_activation
 
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         input_act_qspec,
     )
     if input_act_qspec.dtype == torch.int32:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             get_16a16w_qnn_ptq_config().weight,
         )
     else:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             input_act_qspec,
         )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
 
 
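For context, the pattern every hunk above repeats is: map each input edge of a node to a quantization spec with annotate_input_qspec_map, annotate the node's own output with annotate_output_qspec, then mark the touched nodes so later passes skip them. Below is a minimal sketch of that flow after the migration, assuming only the torchao API imported in this diff; annotate_add is a hypothetical op annotator, and QuantizationConfig, _is_annotated, and _mark_nodes_as_annotated stand in for this file's own definitions, which are not shown here.

# Minimal sketch, not this file's actual code: the post-migration annotation
# flow, wired the same way as the annotators in the diff above.
from torch.fx import Node

from torchao.quantization.pt2e.quantizer import (
    annotate_input_qspec_map,
    annotate_output_qspec,
)


def annotate_add(node: Node, quantization_config) -> None:
    # Hypothetical example op: elementwise add with two float inputs.
    lhs, rhs = node.args[0], node.args[1]
    if _is_annotated([node]):  # local helper from this file, assumed in scope
        return

    # Map each input edge to the activation qspec...
    annotate_input_qspec_map(node, lhs, quantization_config.input_activation)
    annotate_input_qspec_map(node, rhs, quantization_config.input_activation)
    # ...annotate the node's output, then mark the node as handled so
    # subsequent annotation passes do not re-annotate it.
    annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated([node])  # local helper from this file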