From c4a8da2d1a0b64ce74e2b13b7649e4776a12d452 Mon Sep 17 00:00:00 2001 From: Ekaterina Ignasheva Date: Thu, 4 Dec 2025 11:51:40 -0800 Subject: [PATCH] Support eliminate_quant_dequant_pairs flag (#16029) Summary: This change adds support for the `eliminate_quant_dequant_pairs` flag in the remove clone ops transform. This flag allows users to control whether quantization-dequantization pairs should be eliminated during the clone removal optimization pass. We need this to control the functionality of this pass in the subsequent diffs. Reviewed By: hsharma35 Differential Revision: D88092217 --- .../aot/tests/test_remove_ops_passes.py | 4 +-- backends/transforms/remove_clone_ops.py | 26 ++++++++++++++----- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index 158ec73cf27..c957eb04b87 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -21,7 +21,7 @@ RemoveAliasCopyOpPass, RemoveBranchedQuantDequant, RemoveCatFromSliceCopyPass, - RemoveCloneOpPass, + RemoveCloneOpsTransformImported, RemoveContiguousOpPass, RemoveDetachCopyPass, RemoveNopAddOpPass, @@ -241,7 +241,7 @@ def test_remove_clone(self) -> None: clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,)) builder.output([clone]) original = builder.get_graph_module() - p = RemoveCloneOpPass() + p = RemoveCloneOpsTransformImported() graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, torch.ops.aten.clone.default), 0 diff --git a/backends/transforms/remove_clone_ops.py b/backends/transforms/remove_clone_ops.py index 79b93af8beb..07cc3e9efb1 100644 --- a/backends/transforms/remove_clone_ops.py +++ b/backends/transforms/remove_clone_ops.py @@ -25,12 +25,18 @@ class RemoveCloneOpsTransform(ExportPass): exir_ops.edge.dim_order_ops._clone_dim_order.default, } - def __init__(self, preserve_input_output_copies: bool = False) -> None: + def __init__( + self, + preserve_input_output_copies: bool = False, + eliminate_quant_dequant_pairs: bool = True, + ) -> None: super().__init__() self._preserve_input_output_copies = preserve_input_output_copies + self._eliminate_quant_dequant_pairs = eliminate_quant_dequant_pairs - def _remove(self, graph_module: torch.fx.GraphModule) -> None: + def _remove(self, graph_module: torch.fx.GraphModule) -> bool: dequant_nodes = [] + modified = False for n in graph_module.graph.nodes: if n.target not in self.clone_ops: @@ -44,6 +50,7 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None: if self._is_input_output_copy(n) and self._preserve_input_output_copies: continue + modified = True to_be_removed = n for user_n in list(n.users.keys()): user_n.replace_input_with(n, n.args[0]) @@ -51,13 +58,18 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None: dequant_nodes += [n.args[0]] graph_module.graph.erase_node(to_be_removed) - eliminate_dq_q(graph_module, dequant_nodes) + if self._eliminate_quant_dequant_pairs: + eliminate_dq_q(graph_module, dequant_nodes) + + return modified def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self._remove(graph_module) - graph_module.recompile() - dead_code_elimination_pass(graph_module) - return PassResult(graph_module, True) + if self._remove(graph_module): + graph_module.recompile() + dead_code_elimination_pass(graph_module) + return PassResult(graph_module, True) + else: + return PassResult(graph_module, False) def _is_non_identity_clone(self, node: torch.fx.Node) -> bool: """Return True if clone has modified memory layout or dim order."""