@@ -30,7 +30,7 @@ class AnnotateQparamsPass(ExportPass):
3030 and add Q->DQ after removing all the Q->DQs.
3131 """
3232
33- deliver_nodes = {
33+ propagate_nodes = {
3434 exir_ops .edge .aten .view_copy .default ,
3535 exir_ops .edge .aten .permute_copy .default ,
3636 exir_ops .edge .aten .squeeze_copy .default ,
@@ -83,7 +83,7 @@ def _impl(node: Node, res_list: List[Node]):
8383 _impl (user , res_list )
8484 return res_list
8585
86- def _deliver_quant_params (self , node : Node ):
86+ def _propagate_quant_params (self , node : Node ):
8787 assert (
8888 quantize_attrs := node .meta .get ("quantize_attrs" )
8989 ), "Must be annotated node."
@@ -98,25 +98,25 @@ def _deliver_quant_params(self, node: Node):
9898 ):
9999 break
100100 node = user
101- # Case1: ...-q-dq(cur)-deliver_node -node(not d-dq)
102- # Case2: deliver_node(delivered)-deliver_node -node(not q-dq)
101+ # Case1: ...-q-dq(cur)-propagate_node -node(not q-dq)
102+ # Case2: propagate_node(propagated)-propagate_node -node(not q-dq)
103103 for idx , user in enumerate (node .users .keys ()):
104- # For the branch who need to be requantized, we deliver the requantize params
104+ # For a branch that needs to be requantized, we propagate the requantize params
105105 user_attrs = requantize_map .get (idx , quantize_attrs )
106- if user .target not in self .deliver_nodes :
106+ if user .target not in self .propagate_nodes :
107107 continue
108108 if len (user .users ) == 1 :
109109 # Possibly no need for checking len(users)>1
110110 user_of_user = list (user .users )[0 ]
111- # node-q-dq-deliver -q-dq not need for delivery
111+ # node-q-dq-propagate -q-dq: no need for propagation
112112 if (
113113 user_of_user .target in QuantConstants .QUANT_OPS_KEY_MAP
114114 or user_of_user .target in QuantConstants .DEQUANT_OPS_KEY_MAP
115115 ):
116116 continue
117- # Deliver quant for node-q-dq-deliver_node -node(not qdq)
117+ # Propagate quant for node-q-dq-propagate_node -node(not qdq)
118118 user .meta ["quantize_attrs" ] = user_attrs
119- self ._deliver_quant_params (user )
119+ self ._propagate_quant_params (user )
120120
121121 def _annotate_requantize (self , node : Node ):
122122 assert (
@@ -153,16 +153,7 @@ def _check_same(requant_obj, ori_obj) -> bool:
153153
154154 def _annotate (self , graph_module : GraphModule ):
155155 for node in graph_module .graph .nodes :
156- if key_map := QuantConstants .DEQUANT_OPS_KEY_MAP .get (node .target , None ):
157- # We will fold node with constant output in the future pass as a constant node
158- # example: Constant->Q->DQ->nodeN->Q->DQ, this seq will be folded to one
159- # We need to store the q-params from last DQ params for quantizing constant value
160- quant_attrs = self .get_quant_attrs (node , key_map )
161- node .meta ["quantize_attrs" ] = quant_attrs
162- continue
163- else :
164- key_map = QuantConstants .QUANT_OPS_KEY_MAP .get (node .target , None )
165- # ignore pre-quantized params now.
156+ key_map = QuantConstants .QUANT_OPS_KEY_MAP .get (node .target , None )
166157 if not key_map :
167158 continue
168159 source_node = node .args [0 ]
@@ -172,46 +163,15 @@ def _annotate(self, graph_module: GraphModule):
172163 ):
173164 # Currently, don't add quant info for d_qd node here.
174165 continue
166+ elif source_node .target == operator .getitem :
167+ source_node = source_node .args [0 ]
175168 quant_attrs = self .get_quant_attrs (node , key_map )
176- assert node .args [0 ].target != operator .getitem , "Not supported now."
177- source_node = node .args [0 ]
178169 source_node .meta ["quantize_attrs" ] = quant_attrs
179170 self ._annotate_requantize (source_node )
180- self ._deliver_quant_params (source_node )
181-
182- def _annotate_real_out (self , graph_module : GraphModule ):
183- for output_nodes in filter (
184- lambda x : x .op == "output" , graph_module .graph .nodes
185- ):
186- output_nodes = list (output_nodes .args [0 ])
187- for idx , output_node in enumerate (output_nodes ):
188- if output_node .target not in [
189- * QuantConstants .QUANT_OPS_KEY_MAP .keys (),
190- * QuantConstants .DEQUANT_OPS_KEY_MAP .keys (),
191- ]:
192- continue
193- while output_node .args [0 ].target in [
194- * QuantConstants .QUANT_OPS_KEY_MAP .keys (),
195- * QuantConstants .DEQUANT_OPS_KEY_MAP .keys (),
196- ]:
197- output_node = output_node .args [0 ]
198- output_nodes [idx ] = output_node
199- for node in output_nodes :
200- if node .target in QuantConstants .QUANT_OPS_KEY_MAP :
201- node .args [0 ].meta ["real_out" ] = True
202- else :
203- node .meta ["real_out" ] = True
204-
205- def _annotate_real_in (self , graph_module : GraphModule ):
206- for in_node in filter (
207- lambda x : is_graph_input (self .edge_program , x ), graph_module .graph .nodes
208- ):
209- in_node .meta ["real_in" ] = True
171+ self ._propagate_quant_params (source_node )
210172
211173 def call (self , graph_module : GraphModule ):
212174 self ._annotate (graph_module )
213- self ._annotate_real_out (graph_module )
214- self ._annotate_real_in (graph_module )
215175 graph_module .recompile ()
216176 return PassResult (graph_module , True )
217177
@@ -223,7 +183,6 @@ def get_quant_attrs(
223183 for key , attr in zip (quant_attr_keys [1 :], quant_node .args [1 :]):
224184 # For channel-wise quantization, params are stored by buffer nodes.
225185 if isinstance (attr , torch .fx .Node ):
226- assert isinstance (attr .target , str ), "Not supported now. "
227186 attr = get_buffer (self .edge_program , attr )
228187 quant_attrs [key ] = attr
229188 quant_attrs ["target" ] = quant_node .target
0 commit comments