@@ -30,7 +30,7 @@ class AnnotateQparamsPass(ExportPass):
3030         and add Q->DQ after removing all the Q->DQs. 
3131    """ 
3232
33-     deliver_nodes  =  {
33+     propagate_nodes  =  {
3434        exir_ops .edge .aten .view_copy .default ,
3535        exir_ops .edge .aten .permute_copy .default ,
3636        exir_ops .edge .aten .squeeze_copy .default ,
@@ -83,7 +83,7 @@ def _impl(node: Node, res_list: List[Node]):
8383            _impl (user , res_list )
8484        return  res_list 
8585
86-     def  _deliver_quant_params (self , node : Node ):
86+     def  _propagate_quant_params (self , node : Node ):
8787        assert  (
8888            quantize_attrs  :=  node .meta .get ("quantize_attrs" )
8989        ), "Must be annotated node." 
@@ -98,25 +98,25 @@ def _deliver_quant_params(self, node: Node):
9898            ):
9999                break 
100100            node  =  user 
101-         # Case1: ...-q-dq(cur)-deliver_node -node(not d-dq) 
102-         # Case2: deliver_node(delivered)-deliver_node -node(not q-dq) 
101+         # Case1: ...-q-dq(cur)-propagate_node -node(not d-dq) 
102+         # Case2: propagate_node(propagateed)-propagate_node -node(not q-dq) 
103103        for  idx , user  in  enumerate (node .users .keys ()):
104-             # For the branch who need to be requantized, we deliver  the requantize params 
104+             # For the branch who need to be requantized, we propagate  the requantize params 
105105            user_attrs  =  requantize_map .get (idx , quantize_attrs )
106-             if  user .target  not  in   self .deliver_nodes :
106+             if  user .target  not  in   self .propagate_nodes :
107107                continue 
108108            if  len (user .users ) ==  1 :
109109                # Possibily no need for checking len(users)>1 
110110                user_of_user  =  list (user .users )[0 ]
111-                 # node-q-dq-deliver -q-dq not need for delivery  
111+                 # node-q-dq-propagate -q-dq not need for propagatey  
112112                if  (
113113                    user_of_user .target  in  QuantConstants .QUANT_OPS_KEY_MAP 
114114                    or  user_of_user .target  in  QuantConstants .DEQUANT_OPS_KEY_MAP 
115115                ):
116116                    continue 
117-             # Deliver  quant for node-q-dq-deliver_node -node(not qdq) 
117+             # propagate  quant for node-q-dq-propagate_node -node(not qdq) 
118118            user .meta ["quantize_attrs" ] =  user_attrs 
119-             self ._deliver_quant_params (user )
119+             self ._propagate_quant_params (user )
120120
121121    def  _annotate_requantize (self , node : Node ):
122122        assert  (
@@ -153,16 +153,7 @@ def _check_same(requant_obj, ori_obj) -> bool:
153153
154154    def  _annotate (self , graph_module : GraphModule ):
155155        for  node  in  graph_module .graph .nodes :
156-             if  key_map  :=  QuantConstants .DEQUANT_OPS_KEY_MAP .get (node .target , None ):
157-                 # We will fold node with constant output in the future pass as a constant node 
158-                 # example: Constant->Q->DQ->nodeN->Q->DQ, this seq will be folded to one 
159-                 # We need to store the q-params from last DQ params for quantizing constant value 
160-                 quant_attrs  =  self .get_quant_attrs (node , key_map )
161-                 node .meta ["quantize_attrs" ] =  quant_attrs 
162-                 continue 
163-             else :
164-                 key_map  =  QuantConstants .QUANT_OPS_KEY_MAP .get (node .target , None )
165-             # ignore pre-quantized params now. 
156+             key_map  =  QuantConstants .QUANT_OPS_KEY_MAP .get (node .target , None )
166157            if  not  key_map :
167158                continue 
168159            source_node  =  node .args [0 ]
@@ -172,46 +163,15 @@ def _annotate(self, graph_module: GraphModule):
172163            ):
173164                # Currently, don't add quant info for d_qd node here. 
174165                continue 
166+             elif  source_node .target  ==  operator .getitem :
167+                 source_node  =  source_node .args [0 ]
175168            quant_attrs  =  self .get_quant_attrs (node , key_map )
176-             assert  node .args [0 ].target  !=  operator .getitem , "Not supported now." 
177-             source_node  =  node .args [0 ]
178169            source_node .meta ["quantize_attrs" ] =  quant_attrs 
179170            self ._annotate_requantize (source_node )
180-             self ._deliver_quant_params (source_node )
181- 
182-     def  _annotate_real_out (self , graph_module : GraphModule ):
183-         for  output_nodes  in  filter (
184-             lambda  x : x .op  ==  "output" , graph_module .graph .nodes 
185-         ):
186-             output_nodes  =  list (output_nodes .args [0 ])
187-             for  idx , output_node  in  enumerate (output_nodes ):
188-                 if  output_node .target  not  in   [
189-                     * QuantConstants .QUANT_OPS_KEY_MAP .keys (),
190-                     * QuantConstants .DEQUANT_OPS_KEY_MAP .keys (),
191-                 ]:
192-                     continue 
193-                 while  output_node .args [0 ].target  in  [
194-                     * QuantConstants .QUANT_OPS_KEY_MAP .keys (),
195-                     * QuantConstants .DEQUANT_OPS_KEY_MAP .keys (),
196-                 ]:
197-                     output_node  =  output_node .args [0 ]
198-                 output_nodes [idx ] =  output_node 
199-             for  node  in  output_nodes :
200-                 if  node .target  in  QuantConstants .QUANT_OPS_KEY_MAP :
201-                     node .args [0 ].meta ["real_out" ] =  True 
202-                 else :
203-                     node .meta ["real_out" ] =  True 
204- 
205-     def  _annotate_real_in (self , graph_module : GraphModule ):
206-         for  in_node  in  filter (
207-             lambda  x : is_graph_input (self .edge_program , x ), graph_module .graph .nodes 
208-         ):
209-             in_node .meta ["real_in" ] =  True 
171+             self ._propagate_quant_params (source_node )
210172
211173    def  call (self , graph_module : GraphModule ):
212174        self ._annotate (graph_module )
213-         self ._annotate_real_out (graph_module )
214-         self ._annotate_real_in (graph_module )
215175        graph_module .recompile ()
216176        return  PassResult (graph_module , True )
217177
@@ -223,7 +183,6 @@ def get_quant_attrs(
223183        for  key , attr  in  zip (quant_attr_keys [1 :], quant_node .args [1 :]):
224184            # For channel-wise quantization, params are stored by buffer nodes. 
225185            if  isinstance (attr , torch .fx .Node ):
226-                 assert  isinstance (attr .target , str ), "Not supported now. " 
227186                attr  =  get_buffer (self .edge_program , attr )
228187            quant_attrs [key ] =  attr 
229188        quant_attrs ["target" ] =  quant_node .target 
0 commit comments