1111
1212
1313from collections import defaultdict
14- from typing import Union
14+ from typing import Any , Union
1515
1616import torch
1717import torch .fx
18+ from torch .ao .quantization .pt2e .prepare import _get_edge_or_node_to_group_id
19+ from torch .ao .quantization .pt2e .prepare import _get_edge_or_node_to_qspec
1820from torch .ao .quantization .quantizer import Quantizer as TorchAOQuantizer
1921from torch .ao .quantization .quantizer .quantizer import QuantizationSpec
20- from torch .ao .quantization .quantizer .quantizer import QuantizationSpecBase
2122from torch .ao .quantization .quantizer .quantizer import SharedQuantizationSpec
2223
2324import nncf
2425from nncf .common .graph .graph import NNCFGraph
25- from nncf .common .logging import nncf_logger
2626from nncf .common .quantization .quantizer_setup import ActivationQuantizationInsertionPoint
2727from nncf .common .quantization .quantizer_setup import QuantizationPointBase
2828from nncf .common .quantization .quantizer_setup import SingleConfigQuantizationPoint
@@ -73,6 +73,15 @@ def _get_quantization_points(
7373 annotated_model : torch .fx .GraphModule ,
7474 qconfig : QuantizerConfig ,
7575 ) -> list [QuantizationPointBase ]:
76+ """
77+ Creates quantization points based on the nodes and edges.
78+
79+ :param from_node: The originating node in the computation graph.
80+ :param to_nodes: The list of destination nodes of the from_node.
81+ :param annotated_model: The torch.fx.GraphModule instance.
82+ :param qconfig: The torch.ao quantization configuration.
83+ :return: A list of NNCF quantization points.
84+ """
7685 to_n = to_nodes [0 ]
7786 if from_node .op == "get_attr" :
7887 _ , metatype = GraphConverter .get_node_type_and_metatype (to_n , annotated_model )
@@ -95,78 +104,102 @@ def _get_quantization_points(
95104 return qps
96105
97106 @staticmethod
98- def _get_node_args (node : torch .fx .Node ):
107+ def _get_node_args (node : torch .fx .Node ) -> tuple [Any , ...]:
108+ """
109+ Correctly retrieves arguments of the given node.
110+
111+ :param node: The given node.
112+ :return: The arguments of the given node.
113+ """
99114 if node .target == torch .ops .aten .cat .default :
100115 return node .args [0 ]
101116 return node .args
102117
@staticmethod
def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
    """
    Converts the torch.ao quantization annotations of the given model to an NNCF quantizer setup.

    Annotated edges and nodes are split into groups by `_get_edge_or_node_to_group_id`.
    Each group is described by one quantization spec; when a group yields more than one
    quantization point, those points are registered as a unified scale group.

    :param annotated: The torch.fx.GraphModule instance with quantization annotations.
    :return: The NNCF quantizer setup built from the annotations.
    :raises nncf.InternalError: If a resolved spec is not a QuantizationSpec or its qscheme is unknown.
    """
    edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated)
    # Node means all output edges should be quantized.
    # Edge means only one edge should be quantized.
    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)

    group_id_vs_edges = defaultdict(set)
    group_id_vs_qspec = {}
    for edge_or_node, group_id in edge_or_node_to_group_id.items():
        if isinstance(edge_or_node, torch.fx.Node):
            # Expand a node annotation into one edge per user of the node.
            target_edges = [(edge_or_node, user) for user in edge_or_node.users]
        else:
            target_edges = [edge_or_node]
        group_id_vs_edges[group_id].update(target_edges)
        # All qspecs should be aligned after the _get_edge_or_node_to_group_id call
        group_id_vs_qspec[group_id] = _unwrap_shared_qspec_safe(
            edge_or_node_to_qspec[edge_or_node], edge_or_node_to_qspec
        )

    per_channel_schemes = [torch.per_channel_affine, torch.per_channel_symmetric]
    per_tensor_schemes = [torch.per_tensor_affine, torch.per_tensor_symmetric]
    symmetric_schemes = [torch.per_channel_symmetric, torch.per_tensor_symmetric]

    q_setup = SingleConfigQuantizerSetup()
    for group_id, edges in group_id_vs_edges.items():
        qspec = group_id_vs_qspec[group_id]
        if qspec is None:
            continue
        if not isinstance(qspec, QuantizationSpec):
            msg = f"Unknown torch.ao quantization spec: {qspec}"
            raise nncf.InternalError(msg)

        if qspec.qscheme in per_channel_schemes:
            per_channel = True
        elif qspec.qscheme in per_tensor_schemes:
            per_channel = False
        else:
            msg = f"Unknown qscheme: {qspec.qscheme}"
            raise nncf.InternalError(msg)

        signed = qspec.dtype is torch.int8
        mode = QuantizationMode.SYMMETRIC if qspec.qscheme in symmetric_schemes else QuantizationMode.ASYMMETRIC
        # An odd quant_min (e.g. -127) is treated as a narrow range.
        # quant_min is optional in QuantizationSpec; a spec without an explicit
        # quant_min is treated as full range instead of failing on `None % 2`.
        narrow_range = qspec.quant_min is not None and qspec.quant_min % 2 != 0
        qconfig = QuantizerConfig(
            mode=mode, signedness_to_force=signed, per_channel=per_channel, narrow_range=narrow_range
        )

        # Join the edges of the group by their source node.
        joined_edges = defaultdict(list)
        for from_node, to_node in edges:
            joined_edges[from_node].append(to_node)

        qps = []
        for from_node, to_nodes in joined_edges.items():
            qps.extend(TorchAOQuantizerAdapter._get_quantization_points(from_node, to_nodes, annotated, qconfig))
        qp_ids = [q_setup.add_independent_quantization_point(qp) for qp in qps]
        # Several points within one group are registered as a unified scale group.
        if len(qp_ids) > 1:
            q_setup.register_unified_scale_group(qp_ids)

    return q_setup
150181
151182
152- def _get_edge_or_node_to_qspec (
153- model : torch .fx .GraphModule ,
154- ) -> dict [EdgeOrNode , QuantizationSpecBase ]:
183+ def _unwrap_shared_qspec_safe (qspec : QuantizationSpec , edge_or_node_to_qspec : dict [EdgeOrNode , QuantizationSpec ]):
155184 """
156- Get a map from EdgeOrNode to quantization spec based on annotations on the nodes.
185+ Iteratively unwraps a given SharedQuantizationSpec to retrieve its actual QuantizationSpec.
186+ It detects cyclic dependencies and enforces a maximum depth limit to prevent infinite recursion.
157187
158- :param model: torch.fx.GraphModule instance.
159- :return: A map from EdgeOrNode to quantization spec based on annotations on the nodes.
188+ :param qspec: The quantization specification to unwrap.
189+ :param edge_or_node_to_qspec: A dictionary mapping EdgeOrNode instances to their respective QuantizationSpec.
190+ :return: The resolved QuantizationSpec.
160191 """
161- edge_or_node_to_qspec : dict [EdgeOrNode , QuantizationSpecBase ] = {}
162- for n in model .graph .nodes :
163- if hasattr (n , "meta" ) and "quantization_annotation" in n .meta :
164- qa = n .meta ["quantization_annotation" ]
165- for input_to_n , qspec in qa .input_qspec_map .items ():
166- input_edge = (input_to_n , n )
167- edge_or_node_to_qspec [input_edge ] = qspec
168- if qa .output_qspec is not None :
169- output_node = n
170- qspec = qa .output_qspec
171- edge_or_node_to_qspec [output_node ] = qspec
172- return edge_or_node_to_qspec
192+ MAX_DEPTH = 1000
193+ i = 0
194+ visited = []
195+ while i < MAX_DEPTH and isinstance (qspec , SharedQuantizationSpec ):
196+ if qspec .edge_or_node in visited :
197+ msg = f"A cycled dependency of the quantization spec is detected { visited + [qspec .edge_or_node ]} "
198+ raise RuntimeError (msg )
199+ visited .append (qspec .edge_or_node )
200+ qspec = edge_or_node_to_qspec [qspec .edge_or_node ]
201+ i += 1
202+ if i == MAX_DEPTH :
203+ msg = f"Shared qspecs referenced to each other more than the limit: { MAX_DEPTH } "
204+ raise RuntimeError (msg )
205+ return qspec
0 commit comments