diff --git a/exir/passes/quantize_io_pass.py b/exir/passes/quantize_io_pass.py index 64b6c14d75e..095b07a1bf7 100644 --- a/exir/passes/quantize_io_pass.py +++ b/exir/passes/quantize_io_pass.py @@ -11,7 +11,7 @@ import torch -from executorch.exir import EdgeProgramManager +from executorch.exir import EdgeProgramManager, ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -39,11 +39,33 @@ def quantize_input( if len(target_placeholder.users) != 1: raise ValueError(f"Input {input_index} has more than one users") quantize = next(iter(target_placeholder.users)) + if quantize.target not in [ + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ]: + raise ValueError( + f"Input {input_index} is not used by a quantize op. It's used by {quantize.target}" + ) + if ( quantize.target - != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default ): - raise ValueError(f"Input {input_index} is not used by a quantize op") + replacement_op_dequant = ( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ) + replacement_op_quant = ( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ) + elif quantize.target == torch.ops.quantized_decomposed.quantize_per_tensor.default: + replacement_op_dequant = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default + ) + replacement_op_quant = ( + torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + else: + raise ValueError(f"Invalid quantize op: {quantize.target}") # If user specified qparams are different from args of quantize op, we do requantization instead of eliminating quantize op need_requant = False @@ -83,7 +105,7 @@ def quantize_input( with exported_program.graph_module.graph.inserting_before(quantize): input_dequant = exported_program.graph_module.graph.call_function( - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + replacement_op_dequant, args=( target_placeholder, *quant_args, @@ -106,10 +128,8 @@ def quantize_input( logger.info(f"Modifying program to take quantized input at index {input_index}") logger.info(f"Quantization parameters: {quant_args}") - target_placeholder.meta["val"] = ( - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( - target_placeholder.meta["val"], *quant_args - ) + target_placeholder.meta["val"] = replacement_op_quant( + target_placeholder.meta["val"], *quant_args ) quantize.replace_all_uses_with(quantize.args[0]) @@ -138,10 +158,10 @@ def quantize_output(exported_program, output_index): ) target_output = output_list[output_index] - if ( - target_output.target - != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default - ): + if target_output.target not in [ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ]: raise ValueError("Output {output_index} is not a dequantize op") dequant = target_output @@ -185,6 +205,7 @@ def __init__( edge_program_manager: EdgeProgramManager, quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]], method_name: Optional[str] = None, + exported_program: Optional[ExportedProgram] = None, ): super().__init__() self.edge_program_manager = edge_program_manager @@ -196,31 +217,49 @@ def __init__( for idx in quantized_inputs_idx: self.quantized_inputs_idx_dict[idx] = None self.param_prefix_name = method_name + self.exported_program = exported_program + self.quant_args = {} - def call(self, graph_module: torch.fx.GraphModule): - for i, qparams in self.quantized_inputs_idx_dict.items(): - quant_args = quantize_input( - self.edge_program_manager.exported_program(), i, qparams - ) - + def edge_manager_update_quant_config_method(self, idx, quant_args): + if self.edge_program_manager is not None: if not self.edge_program_manager._config_methods: self.edge_program_manager._config_methods = {} self.edge_program_manager._config_methods[ - get_config_method_name(self.param_prefix_name, "input", i, "scale") + get_config_method_name(self.param_prefix_name, "input", idx, "scale") ] = quant_args[0] - self.edge_program_manager._config_methods[ # pyre-ignore - get_config_method_name(self.param_prefix_name, "input", i, "zp") + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "input", idx, "zp") ] = quant_args[1] self.edge_program_manager._config_methods[ - get_config_method_name(self.param_prefix_name, "input", i, "quant_min") + get_config_method_name( + self.param_prefix_name, "input", idx, "quant_min" + ) ] = quant_args[2] self.edge_program_manager._config_methods[ - get_config_method_name(self.param_prefix_name, "input", i, "quant_max") + get_config_method_name( + self.param_prefix_name, "input", idx, "quant_max" + ) ] = quant_args[3] self.edge_program_manager._config_methods[ - get_config_method_name(self.param_prefix_name, "input", i, "dtype") + get_config_method_name(self.param_prefix_name, "input", idx, "dtype") ] = scalar_type_enum(quant_args[4]) + + def edge_manager_update_quant_config_methods_all(self): + if self.edge_program_manager is not None: + for idx, val in self.quant_args.items(): + self.edge_manager_update_quant_config_method(idx, val) + + def call(self, graph_module: torch.fx.GraphModule): + for i, qparams in self.quantized_inputs_idx_dict.items(): + exported_program = ( + self.edge_program_manager.exported_program() + if self.edge_program_manager is not None + else self.exported_program + ) + self.quant_args[i] = quantize_input(exported_program, i, qparams) + self.edge_manager_update_quant_config_method(i, self.quant_args[i]) + return PassResult(graph_module, True) @@ -230,35 +269,53 @@ def __init__( edge_program_manager: EdgeProgramManager, quantized_outputs_idx_list: List[int], method_name: Optional[str] = None, + exported_program: Optional[ExportedProgram] = None, ): super().__init__() self.edge_program_manager = edge_program_manager self.quantized_outputs_idx_list = quantized_outputs_idx_list self.param_prefix_name = method_name + self.exported_program = exported_program + self.dequant_args = {} - def call(self, graph_module: torch.fx.GraphModule): - for i in self.quantized_outputs_idx_list: - dequant_args = quantize_output( - self.edge_program_manager.exported_program(), i - ) # noqa F841 - + def edge_manager_update_quant_config_method(self, idx, dequant_args): + if self.edge_program_manager is not None: if not self.edge_program_manager._config_methods: self.edge_program_manager._config_methods = {} self.edge_program_manager._config_methods[ - get_config_method_name(self.param_prefix_name, "output", i, "scale") + get_config_method_name(self.param_prefix_name, "output", idx, "scale") ] = dequant_args[0] - self.edge_program_manager._config_methods[ # pyre-ignore - get_config_method_name(self.param_prefix_name, "output", i, "zp") + self.edge_program_manager._config_methods[ + get_config_method_name(self.param_prefix_name, "output", idx, "zp") ] = dequant_args[1] self.edge_program_manager._config_methods[ - get_config_method_name(self.param_prefix_name, "output", i, "quant_min") + get_config_method_name( + self.param_prefix_name, "output", idx, "quant_min" + ) ] = dequant_args[2] self.edge_program_manager._config_methods[ - get_config_method_name(self.param_prefix_name, "output", i, "quant_max") + get_config_method_name( + self.param_prefix_name, "output", idx, "quant_max" + ) ] = dequant_args[3] self.edge_program_manager._config_methods[ - get_config_method_name(self.param_prefix_name, "output", i, "dtype") + get_config_method_name(self.param_prefix_name, "output", idx, "dtype") ] = scalar_type_enum(dequant_args[4]) + def edge_manager_update_quant_config_methods_all(self): + if self.edge_program_manager is not None: + for idx, val in self.dequant_args.items(): + self.edge_manager_update_quant_config_method(idx, val) + + def call(self, graph_module: torch.fx.GraphModule): + for i in self.quantized_outputs_idx_list: + exported_program = ( + self.edge_program_manager.exported_program() + if self.edge_program_manager is not None + else self.exported_program + ) + self.dequant_args[i] = quantize_output(exported_program, i) # noqa F841 + self.edge_manager_update_quant_config_method(i, self.dequant_args[i]) + return PassResult(graph_module, True)