1111
1212import  torch 
1313
14- from  executorch .exir  import  EdgeProgramManager 
14+ from  executorch .exir  import  EdgeProgramManager ,  ExportedProgram 
1515from  executorch .exir .dialects ._ops  import  ops  as  exir_ops 
1616
1717from  executorch .exir .pass_base  import  ExportPass 
@@ -39,11 +39,33 @@ def quantize_input(
3939    if  len (target_placeholder .users ) !=  1 :
4040        raise  ValueError (f"Input { input_index }  )
4141    quantize  =  next (iter (target_placeholder .users ))
42+     if  quantize .target  not  in 
43+         exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
44+         torch .ops .quantized_decomposed .quantize_per_tensor .default ,
45+     ]:
46+         raise  ValueError (
47+             f"Input { input_index } { quantize .target }  
48+         )
49+ 
4250    if  (
4351        quantize .target 
44-         ! =exir_ops .edge .quantized_decomposed .quantize_per_tensor .default 
52+         = =exir_ops .edge .quantized_decomposed .quantize_per_tensor .default 
4553    ):
46-         raise  ValueError (f"Input { input_index }  )
54+         replacement_op_dequant  =  (
55+             exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default 
56+         )
57+         replacement_op_quant  =  (
58+             exir_ops .edge .quantized_decomposed .quantize_per_tensor .default 
59+         )
60+     elif  quantize .target  ==  torch .ops .quantized_decomposed .quantize_per_tensor .default :
61+         replacement_op_dequant  =  (
62+             torch .ops .quantized_decomposed .dequantize_per_tensor .default 
63+         )
64+         replacement_op_quant  =  (
65+             torch .ops .quantized_decomposed .quantize_per_tensor .default 
66+         )
67+     else :
68+         raise  ValueError (f"Invalid quantize op: { quantize .target }  )
4769
4870    # If user specified qparams are different from args of quantize op, we do requantization instead of eliminating quantize op 
4971    need_requant  =  False 
@@ -83,7 +105,7 @@ def quantize_input(
83105
84106        with  exported_program .graph_module .graph .inserting_before (quantize ):
85107            input_dequant  =  exported_program .graph_module .graph .call_function (
86-                 exir_ops . edge . quantized_decomposed . dequantize_per_tensor . default ,
108+                 replacement_op_dequant ,
87109                args = (
88110                    target_placeholder ,
89111                    * quant_args ,
@@ -106,10 +128,8 @@ def quantize_input(
106128        logger .info (f"Modifying program to take quantized input at index { input_index }  )
107129        logger .info (f"Quantization parameters: { quant_args }  )
108130
109-         target_placeholder .meta ["val" ] =  (
110-             exir_ops .edge .quantized_decomposed .quantize_per_tensor .default (
111-                 target_placeholder .meta ["val" ], * quant_args 
112-             )
131+         target_placeholder .meta ["val" ] =  replacement_op_quant (
132+             target_placeholder .meta ["val" ], * quant_args 
113133        )
114134        quantize .replace_all_uses_with (quantize .args [0 ])
115135
@@ -138,10 +158,10 @@ def quantize_output(exported_program, output_index):
138158        )
139159
140160    target_output  =  output_list [output_index ]
141-     if  ( 
142-         target_output . target 
143-         !=   exir_ops . edge .quantized_decomposed .dequantize_per_tensor .default 
144-     ) :
161+     if  target_output . target   not   in  [ 
162+         exir_ops . edge . quantized_decomposed . dequantize_per_tensor . default , 
163+         torch . ops .quantized_decomposed .dequantize_per_tensor .default , 
164+     ] :
145165        raise  ValueError ("Output {output_index} is not a dequantize op" )
146166
147167    dequant  =  target_output 
@@ -185,6 +205,7 @@ def __init__(
185205        edge_program_manager : EdgeProgramManager ,
186206        quantized_inputs_idx : Union [Dict [int , Dict [str , Any ]], List [int ]],
187207        method_name : Optional [str ] =  None ,
208+         exported_program : Optional [ExportedProgram ] =  None ,
188209    ):
189210        super ().__init__ ()
190211        self .edge_program_manager  =  edge_program_manager 
@@ -196,31 +217,49 @@ def __init__(
196217            for  idx  in  quantized_inputs_idx :
197218                self .quantized_inputs_idx_dict [idx ] =  None 
198219        self .param_prefix_name  =  method_name 
220+         self .exported_program  =  exported_program 
221+         self .quant_args  =  {}
199222
200-     def  call (self , graph_module : torch .fx .GraphModule ):
201-         for  i , qparams  in  self .quantized_inputs_idx_dict .items ():
202-             quant_args  =  quantize_input (
203-                 self .edge_program_manager .exported_program (), i , qparams 
204-             )
205- 
223+     def  edge_manager_update_quant_config_method (self , idx , quant_args ):
224+         if  self .edge_program_manager  is  not None :
206225            if  not  self .edge_program_manager ._config_methods :
207226                self .edge_program_manager ._config_methods  =  {}
208227
209228            self .edge_program_manager ._config_methods [
210-                 get_config_method_name (self .param_prefix_name , "input" , i , "scale" )
229+                 get_config_method_name (self .param_prefix_name , "input" , idx , "scale" )
211230            ] =  quant_args [0 ]
212-             self .edge_program_manager ._config_methods [   # pyre-ignore 
213-                 get_config_method_name (self .param_prefix_name , "input" , i , "zp" )
231+             self .edge_program_manager ._config_methods [
232+                 get_config_method_name (self .param_prefix_name , "input" , idx , "zp" )
214233            ] =  quant_args [1 ]
215234            self .edge_program_manager ._config_methods [
216-                 get_config_method_name (self .param_prefix_name , "input" , i , "quant_min" )
235+                 get_config_method_name (
236+                     self .param_prefix_name , "input" , idx , "quant_min" 
237+                 )
217238            ] =  quant_args [2 ]
218239            self .edge_program_manager ._config_methods [
219-                 get_config_method_name (self .param_prefix_name , "input" , i , "quant_max" )
240+                 get_config_method_name (
241+                     self .param_prefix_name , "input" , idx , "quant_max" 
242+                 )
220243            ] =  quant_args [3 ]
221244            self .edge_program_manager ._config_methods [
222-                 get_config_method_name (self .param_prefix_name , "input" , i , "dtype" )
245+                 get_config_method_name (self .param_prefix_name , "input" , idx , "dtype" )
223246            ] =  scalar_type_enum (quant_args [4 ])
247+ 
248+     def  edge_manager_update_quant_config_methods_all (self ):
249+         if  self .edge_program_manager  is  not None :
250+             for  idx , val  in  self .quant_args .items ():
251+                 self .edge_manager_update_quant_config_method (idx , val )
252+ 
253+     def  call (self , graph_module : torch .fx .GraphModule ):
254+         for  i , qparams  in  self .quantized_inputs_idx_dict .items ():
255+             exported_program  =  (
256+                 self .edge_program_manager .exported_program ()
257+                 if  self .edge_program_manager  is  not None 
258+                 else  self .exported_program 
259+             )
260+             self .quant_args [i ] =  quantize_input (exported_program , i , qparams )
261+             self .edge_manager_update_quant_config_method (i , self .quant_args [i ])
262+ 
224263        return  PassResult (graph_module , True )
225264
226265
@@ -230,35 +269,53 @@ def __init__(
230269        edge_program_manager : EdgeProgramManager ,
231270        quantized_outputs_idx_list : List [int ],
232271        method_name : Optional [str ] =  None ,
272+         exported_program : Optional [ExportedProgram ] =  None ,
233273    ):
234274        super ().__init__ ()
235275        self .edge_program_manager  =  edge_program_manager 
236276        self .quantized_outputs_idx_list  =  quantized_outputs_idx_list 
237277        self .param_prefix_name  =  method_name 
278+         self .exported_program  =  exported_program 
279+         self .dequant_args  =  {}
238280
239-     def  call (self , graph_module : torch .fx .GraphModule ):
240-         for  i  in  self .quantized_outputs_idx_list :
241-             dequant_args  =  quantize_output (
242-                 self .edge_program_manager .exported_program (), i 
243-             )  # noqa F841 
244- 
281+     def  edge_manager_update_quant_config_method (self , idx , dequant_args ):
282+         if  self .edge_program_manager  is  not None :
245283            if  not  self .edge_program_manager ._config_methods :
246284                self .edge_program_manager ._config_methods  =  {}
247285
248286            self .edge_program_manager ._config_methods [
249-                 get_config_method_name (self .param_prefix_name , "output" , i , "scale" )
287+                 get_config_method_name (self .param_prefix_name , "output" , idx , "scale" )
250288            ] =  dequant_args [0 ]
251-             self .edge_program_manager ._config_methods [   # pyre-ignore 
252-                 get_config_method_name (self .param_prefix_name , "output" , i , "zp" )
289+             self .edge_program_manager ._config_methods [
290+                 get_config_method_name (self .param_prefix_name , "output" , idx , "zp" )
253291            ] =  dequant_args [1 ]
254292            self .edge_program_manager ._config_methods [
255-                 get_config_method_name (self .param_prefix_name , "output" , i , "quant_min" )
293+                 get_config_method_name (
294+                     self .param_prefix_name , "output" , idx , "quant_min" 
295+                 )
256296            ] =  dequant_args [2 ]
257297            self .edge_program_manager ._config_methods [
258-                 get_config_method_name (self .param_prefix_name , "output" , i , "quant_max" )
298+                 get_config_method_name (
299+                     self .param_prefix_name , "output" , idx , "quant_max" 
300+                 )
259301            ] =  dequant_args [3 ]
260302            self .edge_program_manager ._config_methods [
261-                 get_config_method_name (self .param_prefix_name , "output" , i , "dtype" )
303+                 get_config_method_name (self .param_prefix_name , "output" , idx , "dtype" )
262304            ] =  scalar_type_enum (dequant_args [4 ])
263305
306+     def  edge_manager_update_quant_config_methods_all (self ):
307+         if  self .edge_program_manager  is  not None :
308+             for  idx , val  in  self .dequant_args .items ():
309+                 self .edge_manager_update_quant_config_method (idx , val )
310+ 
311+     def  call (self , graph_module : torch .fx .GraphModule ):
312+         for  i  in  self .quantized_outputs_idx_list :
313+             exported_program  =  (
314+                 self .edge_program_manager .exported_program ()
315+                 if  self .edge_program_manager  is  not None 
316+                 else  self .exported_program 
317+             )
318+             self .dequant_args [i ] =  quantize_output (exported_program , i )  # noqa F841 
319+             self .edge_manager_update_quant_config_method (i , self .dequant_args [i ])
320+ 
264321        return  PassResult (graph_module , True )
0 commit comments