Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/cadence/aot/tests/test_remove_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
RemoveAliasCopyOpPass,
RemoveBranchedQuantDequant,
RemoveCatFromSliceCopyPass,
RemoveCloneOpPass,
RemoveCloneOpsTransformImported,
RemoveContiguousOpPass,
RemoveDetachCopyPass,
RemoveNopAddOpPass,
Expand Down Expand Up @@ -241,7 +241,7 @@ def test_remove_clone(self) -> None:
clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,))
builder.output([clone])
original = builder.get_graph_module()
p = RemoveCloneOpPass()
p = RemoveCloneOpsTransformImported()
graph_after_passes = cast(PassResult, p(original)).graph_module
self.assertEqual(
count_node(graph_after_passes, torch.ops.aten.clone.default), 0
Expand Down
26 changes: 19 additions & 7 deletions backends/transforms/remove_clone_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,18 @@ class RemoveCloneOpsTransform(ExportPass):
exir_ops.edge.dim_order_ops._clone_dim_order.default,
}

def __init__(self, preserve_input_output_copies: bool = False) -> None:
def __init__(
    self,
    preserve_input_output_copies: bool = False,
    eliminate_quant_dequant_pairs: bool = True,
) -> None:
    """Configure which clone-removal behaviors this pass applies.

    Args:
        preserve_input_output_copies: If True, clone ops that copy a graph
            input directly to a graph output are kept rather than removed
            (checked via ``_is_input_output_copy`` in ``_remove``).
        eliminate_quant_dequant_pairs: If True, after clones are removed,
            ``eliminate_dq_q`` is invoked on the dequantize nodes that fed
            the removed clones, folding newly adjacent dequant/quant pairs.
            Defaults to True to preserve the pass's previous behavior.
    """
    super().__init__()
    self._preserve_input_output_copies = preserve_input_output_copies
    self._eliminate_quant_dequant_pairs = eliminate_quant_dequant_pairs

def _remove(self, graph_module: torch.fx.GraphModule) -> None:
def _remove(self, graph_module: torch.fx.GraphModule) -> bool:
dequant_nodes = []
modified = False

for n in graph_module.graph.nodes:
if n.target not in self.clone_ops:
Expand All @@ -44,20 +50,26 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
if self._is_input_output_copy(n) and self._preserve_input_output_copies:
continue

modified = True
to_be_removed = n
for user_n in list(n.users.keys()):
user_n.replace_input_with(n, n.args[0])
if n.args[0].target in _DEQUANT_OPS:
dequant_nodes += [n.args[0]]
graph_module.graph.erase_node(to_be_removed)

eliminate_dq_q(graph_module, dequant_nodes)
if self._eliminate_quant_dequant_pairs:
eliminate_dq_q(graph_module, dequant_nodes)

return modified

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self._remove(graph_module)
graph_module.recompile()
dead_code_elimination_pass(graph_module)
return PassResult(graph_module, True)
if self._remove(graph_module):
graph_module.recompile()
dead_code_elimination_pass(graph_module)
return PassResult(graph_module, True)
else:
return PassResult(graph_module, False)

def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
"""Return True if clone has modified memory layout or dim order."""
Expand Down
Loading