1111
1212import torch
1313
14- from executorch .exir import EdgeProgramManager
14+ from executorch .exir import EdgeProgramManager , ExportedProgram
1515from executorch .exir .dialects ._ops import ops as exir_ops
1616
1717from executorch .exir .pass_base import ExportPass
@@ -39,11 +39,33 @@ def quantize_input(
3939 if len (target_placeholder .users ) != 1 :
4040 raise ValueError (f"Input { input_index } has more than one users" )
4141 quantize = next (iter (target_placeholder .users ))
42+ if quantize .target not in [
43+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
44+ torch .ops .quantized_decomposed .quantize_per_tensor .default ,
45+ ]:
46+ raise ValueError (
47+ f"Input { input_index } is not used by a quantize op. It's used by { quantize .target } "
48+ )
49+
4250 if (
4351 quantize .target
44- ! = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
52+ = = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
4553 ):
46- raise ValueError (f"Input { input_index } is not used by a quantize op" )
54+ replacement_op_dequant = (
55+ exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default
56+ )
57+ replacement_op_quant = (
58+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
59+ )
60+ elif quantize .target == torch .ops .quantized_decomposed .quantize_per_tensor .default :
61+ replacement_op_dequant = (
62+ torch .ops .quantized_decomposed .dequantize_per_tensor .default
63+ )
64+ replacement_op_quant = (
65+ torch .ops .quantized_decomposed .quantize_per_tensor .default
66+ )
67+ else :
68+ raise ValueError (f"Invalid quantize op: { quantize .target } " )
4769
4870 # If user specified qparams are different from args of quantize op, we do requantization instead of eliminating quantize op
4971 need_requant = False
@@ -83,7 +105,7 @@ def quantize_input(
83105
84106 with exported_program .graph_module .graph .inserting_before (quantize ):
85107 input_dequant = exported_program .graph_module .graph .call_function (
86- exir_ops . edge . quantized_decomposed . dequantize_per_tensor . default ,
108+ replacement_op_dequant ,
87109 args = (
88110 target_placeholder ,
89111 * quant_args ,
@@ -106,10 +128,8 @@ def quantize_input(
106128 logger .info (f"Modifying program to take quantized input at index { input_index } " )
107129 logger .info (f"Quantization parameters: { quant_args } " )
108130
109- target_placeholder .meta ["val" ] = (
110- exir_ops .edge .quantized_decomposed .quantize_per_tensor .default (
111- target_placeholder .meta ["val" ], * quant_args
112- )
131+ target_placeholder .meta ["val" ] = replacement_op_quant (
132+ target_placeholder .meta ["val" ], * quant_args
113133 )
114134 quantize .replace_all_uses_with (quantize .args [0 ])
115135
@@ -138,10 +158,10 @@ def quantize_output(exported_program, output_index):
138158 )
139159
140160 target_output = output_list [output_index ]
141- if (
142- target_output . target
143- != exir_ops . edge .quantized_decomposed .dequantize_per_tensor .default
144- ) :
161+ if target_output . target not in [
162+ exir_ops . edge . quantized_decomposed . dequantize_per_tensor . default ,
163+ torch . ops .quantized_decomposed .dequantize_per_tensor .default ,
164+ ] :
145165 raise ValueError ("Output {output_index} is not a dequantize op" )
146166
147167 dequant = target_output
@@ -185,6 +205,7 @@ def __init__(
185205 edge_program_manager : EdgeProgramManager ,
186206 quantized_inputs_idx : Union [Dict [int , Dict [str , Any ]], List [int ]],
187207 method_name : Optional [str ] = None ,
208+ exported_program : Optional [ExportedProgram ] = None ,
188209 ):
189210 super ().__init__ ()
190211 self .edge_program_manager = edge_program_manager
@@ -196,31 +217,49 @@ def __init__(
196217 for idx in quantized_inputs_idx :
197218 self .quantized_inputs_idx_dict [idx ] = None
198219 self .param_prefix_name = method_name
220+ self .exported_program = exported_program
221+ self .quant_args = {}
199222
200- def call (self , graph_module : torch .fx .GraphModule ):
201- for i , qparams in self .quantized_inputs_idx_dict .items ():
202- quant_args = quantize_input (
203- self .edge_program_manager .exported_program (), i , qparams
204- )
205-
def edge_manager_update_quant_config_method(self, idx, quant_args):
    """Publish one input's quantization parameters as edge-manager config methods.

    Does nothing when ``self.edge_program_manager`` is None. Otherwise makes
    sure the manager's ``_config_methods`` mapping exists and stores five
    entries for input ``idx``, keyed by
    ``get_config_method_name(self.param_prefix_name, "input", idx, <field>)``.

    Args:
        idx: Index of the input the parameters belong to.
        quant_args: Indexable holding, by position, scale (0), zero point (1),
            quant_min (2), quant_max (3) and dtype (4). The dtype is passed
            through ``scalar_type_enum`` before being stored.
    """
    manager = self.edge_program_manager
    if manager is None:
        return

    # A missing or empty mapping is replaced by a fresh dict.
    if not manager._config_methods:
        manager._config_methods = {}
    config_methods = manager._config_methods

    # The first four fields are stored verbatim, in positional order.
    for position, field in enumerate(("scale", "zp", "quant_min", "quant_max")):
        value = quant_args[position]
        key = get_config_method_name(self.param_prefix_name, "input", idx, field)
        config_methods[key] = value

    # dtype is converted before it is recorded.
    dtype_value = scalar_type_enum(quant_args[4])
    dtype_key = get_config_method_name(self.param_prefix_name, "input", idx, "dtype")
    config_methods[dtype_key] = dtype_value
247+
def edge_manager_update_quant_config_methods_all(self):
    """Publish every entry of ``self.quant_args`` to the edge manager.

    Each ``(index, args)`` pair stored in ``self.quant_args`` is forwarded to
    ``edge_manager_update_quant_config_method``. Nothing happens when
    ``self.edge_program_manager`` is None.
    """
    if self.edge_program_manager is None:
        return
    for input_idx, recorded_args in self.quant_args.items():
        self.edge_manager_update_quant_config_method(input_idx, recorded_args)
252+
def call(self, graph_module: torch.fx.GraphModule):
    """Quantize the configured inputs and record their parameters.

    For every ``(index, qparams)`` entry in ``self.quantized_inputs_idx_dict``
    the target program is resolved (the edge manager's ``exported_program()``
    when a manager is set, otherwise ``self.exported_program``) and
    ``quantize_input`` is applied to it. The returned arguments are stored in
    ``self.quant_args[index]`` and forwarded to
    ``edge_manager_update_quant_config_method``.

    Args:
        graph_module: Passed through unchanged into the returned result.

    Returns:
        ``PassResult(graph_module, True)``.

    Raises:
        ValueError: If there is an input to quantize but neither
            ``edge_program_manager`` nor ``exported_program`` was supplied.
    """
    for i, qparams in self.quantized_inputs_idx_dict.items():
        if self.edge_program_manager is not None:
            exported_program = self.edge_program_manager.exported_program()
        else:
            exported_program = self.exported_program

        # Fail early with a clear message instead of handing None to
        # quantize_input and surfacing an unrelated attribute error.
        if exported_program is None:
            raise ValueError(
                "Either edge_program_manager or exported_program must be "
                "provided to quantize inputs"
            )

        self.quant_args[i] = quantize_input(exported_program, i, qparams)
        self.edge_manager_update_quant_config_method(i, self.quant_args[i])

    return PassResult(graph_module, True)
225264
226265
@@ -230,35 +269,53 @@ def __init__(
230269 edge_program_manager : EdgeProgramManager ,
231270 quantized_outputs_idx_list : List [int ],
232271 method_name : Optional [str ] = None ,
272+ exported_program : Optional [ExportedProgram ] = None ,
233273 ):
234274 super ().__init__ ()
235275 self .edge_program_manager = edge_program_manager
236276 self .quantized_outputs_idx_list = quantized_outputs_idx_list
237277 self .param_prefix_name = method_name
278+ self .exported_program = exported_program
279+ self .dequant_args = {}
238280
239- def call (self , graph_module : torch .fx .GraphModule ):
240- for i in self .quantized_outputs_idx_list :
241- dequant_args = quantize_output (
242- self .edge_program_manager .exported_program (), i
243- ) # noqa F841
244-
def edge_manager_update_quant_config_method(self, idx, dequant_args):
    """Publish one output's dequantization parameters as edge-manager config methods.

    Does nothing when ``self.edge_program_manager`` is None. Otherwise makes
    sure the manager's ``_config_methods`` mapping exists and stores five
    entries for output ``idx``, keyed by
    ``get_config_method_name(self.param_prefix_name, "output", idx, <field>)``.

    Args:
        idx: Index of the output the parameters belong to.
        dequant_args: Indexable holding, by position, scale (0), zero point
            (1), quant_min (2), quant_max (3) and dtype (4). The dtype is
            passed through ``scalar_type_enum`` before being stored.
    """
    manager = self.edge_program_manager
    if manager is None:
        return

    # A missing or empty mapping is replaced by a fresh dict.
    if not manager._config_methods:
        manager._config_methods = {}
    config_methods = manager._config_methods

    # The first four fields are stored verbatim, in positional order.
    for position, field in enumerate(("scale", "zp", "quant_min", "quant_max")):
        value = dequant_args[position]
        key = get_config_method_name(self.param_prefix_name, "output", idx, field)
        config_methods[key] = value

    # dtype is converted before it is recorded.
    dtype_value = scalar_type_enum(dequant_args[4])
    dtype_key = get_config_method_name(self.param_prefix_name, "output", idx, "dtype")
    config_methods[dtype_key] = dtype_value
263305
def edge_manager_update_quant_config_methods_all(self):
    """Publish every entry of ``self.dequant_args`` to the edge manager.

    Each ``(index, args)`` pair stored in ``self.dequant_args`` is forwarded
    to ``edge_manager_update_quant_config_method``. Nothing happens when
    ``self.edge_program_manager`` is None.
    """
    if self.edge_program_manager is None:
        return
    for output_idx, recorded_args in self.dequant_args.items():
        self.edge_manager_update_quant_config_method(output_idx, recorded_args)
310+
def call(self, graph_module: torch.fx.GraphModule):
    """Quantize the configured outputs and record their parameters.

    For every index in ``self.quantized_outputs_idx_list`` the target program
    is resolved (the edge manager's ``exported_program()`` when a manager is
    set, otherwise ``self.exported_program``) and ``quantize_output`` is
    applied to it. The returned arguments are stored in
    ``self.dequant_args[index]`` and forwarded to
    ``edge_manager_update_quant_config_method``.

    Args:
        graph_module: Passed through unchanged into the returned result.

    Returns:
        ``PassResult(graph_module, True)``.

    Raises:
        ValueError: If there is an output to quantize but neither
            ``edge_program_manager`` nor ``exported_program`` was supplied.
    """
    for i in self.quantized_outputs_idx_list:
        if self.edge_program_manager is not None:
            exported_program = self.edge_program_manager.exported_program()
        else:
            exported_program = self.exported_program

        # Fail early with a clear message instead of handing None to
        # quantize_output and surfacing an unrelated attribute error.
        if exported_program is None:
            raise ValueError(
                "Either edge_program_manager or exported_program must be "
                "provided to quantize outputs"
            )

        self.dequant_args[i] = quantize_output(exported_program, i)
        self.edge_manager_update_quant_config_method(i, self.dequant_args[i])

    return PassResult(graph_module, True)
0 commit comments