diff --git a/backends/transforms/remove_clone_ops.py b/backends/transforms/remove_clone_ops.py
index 50003dac925..01fe2ee26a4 100644
--- a/backends/transforms/remove_clone_ops.py
+++ b/backends/transforms/remove_clone_ops.py
@@ -22,6 +22,7 @@ class RemoveCloneOpsTransform(ExportPass):
 
     clone_ops: Set[torch._ops.OpOverload] = {
         exir_ops.edge.aten.clone.default,
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
     }
 
     def __init__(self) -> None:
@@ -34,12 +35,15 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
             if n.target not in self.clone_ops:
                 continue
 
-            to_be_remove = n
+            if self._is_non_identity_clone(n):
+                continue
+
+            to_be_removed = n
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])
                 if n.args[0].target in _DEQUANT_OPS:
                     dequant_nodes += [n.args[0]]
-            graph_module.graph.erase_node(to_be_remove)
+            graph_module.graph.erase_node(to_be_removed)
 
         eliminate_dq_q(graph_module, dequant_nodes)
 
@@ -48,3 +52,27 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
         return PassResult(graph_module, True)
+
+    def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
+        """Return True if the clone changes the memory layout or dim order."""
+
+        # aten.clone: check for memory_format changes.
+        if node.target == exir_ops.edge.aten.clone.default:
+            memory_format = node.kwargs.get("memory_format")
+            if memory_format in (None, torch.preserve_format):
+                return False
+            input_meta = node.args[0].meta
+            return "val" in input_meta and not input_meta["val"].is_contiguous(
+                memory_format=memory_format
+            )
+
+        # _clone_dim_order: check for dim_order changes.
+        if node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default:
+            input_meta = node.args[0].meta
+            return (
+                "val" in node.meta
+                and "val" in input_meta
+                and node.meta["val"].dim_order() != input_meta["val"].dim_order()
+            )
+
+        return False
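For context, the new _is_non_identity_clone helper leans on two tensor predicates: is_contiguous(memory_format=...) tells whether an aten.clone with an explicit memory_format would actually reorder data, and dim_order() exposes the dimension permutation that _clone_dim_order compares. A minimal standalone sketch of both, assuming a PyTorch build where Tensor.dim_order() is available (as in the versions ExecuTorch targets):

import torch

# A contiguous NCHW tensor and a channels-last copy of it.
x = torch.randn(2, 3, 4, 5)
y = x.clone(memory_format=torch.channels_last)

# aten.clone is an identity clone only when its input is already
# contiguous in the requested memory_format.
print(x.is_contiguous(memory_format=torch.channels_last))  # False -> keep the clone
print(y.is_contiguous(memory_format=torch.channels_last))  # True  -> clone is removable

# _clone_dim_order instead compares the dim orders of output and input.
print(x.dim_order())  # (0, 1, 2, 3)
print(y.dim_order())  # (0, 2, 3, 1)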
diff --git a/backends/transforms/test/test_remove_clone_ops.py b/backends/transforms/test/test_remove_clone_ops.py
index 5d7a1ecd59f..d34c522baaa 100644
--- a/backends/transforms/test/test_remove_clone_ops.py
+++ b/backends/transforms/test/test_remove_clone_ops.py
@@ -8,13 +8,30 @@
 import torch
 
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
+from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dim_order_utils import is_channel_last_dim_order
+from executorch.exir.tests.test_memory_format_ops_pass_utils import (
+    SimpleCloneChannelsLastModule,
+)
+from torch.export import export
 from torch.fx import GraphModule
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import TestCase
 
 
 class TestRemoveCloneOpsTransform(TestCase):
+    # Clone ops can appear as either aten.clone or _clone_dim_order, depending on the _skip_dim_order flag:
+    #   _skip_dim_order=True tests aten.clone
+    #   _skip_dim_order=False tests _clone_dim_order
+    CLONE_OP_CASES = [
+        (True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
+        (
+            False,
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
+        ),
+    ]
+
     def test_dq_clone_q_linear(self):
         """
         Test RemoveCloneOpsTransform on a graph with d/q -> clone -> q -> linear pattern
@@ -123,6 +140,58 @@ def forward(self, x):
                 transformed_gm.code
             )
 
+    def test_clone_non_identity_survives(self):
+        """Verify clone ops that change the memory layout survive RemoveCloneOpsTransform."""
+
+        for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
+            model = SimpleCloneChannelsLastModule()
+            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
+
+            exported = export(model.eval(), (x,), strict=True)
+            before_epm = to_edge(
+                exported,
+                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
+            )
+
+            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
+                updated_epm.exported_program().graph_module.code
+            )
+
+            expected = before_epm.exported_program().module()(x)
+            actual = updated_epm.exported_program().module()(x)
+            assert torch.allclose(actual, expected)
+            assert is_channel_last_dim_order(actual)
+
+    def test_clone_identity_removed(self):
+        """Verify identity clone ops are removed by RemoveCloneOpsTransform."""
+
+        for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
+            model = SimpleCloneChannelsLastModule()
+            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
+
+            exported = export(model.eval(), (x,), strict=True)
+            before_epm = to_edge(
+                exported,
+                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
+            )
+
+            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
+                before_epm.exported_program().graph_module.code
+            )
+
+            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+            FileCheck().check_not(clone_op_str).run(
+                updated_epm.exported_program().graph_module.code
+            )
+
+            expected = before_epm.exported_program().module()(x)
+            actual = updated_epm.exported_program().module()(x)
+            assert torch.allclose(actual, expected)
+            assert is_channel_last_dim_order(actual)
+
 
 if __name__ == "__main__":
     unittest.main()
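A quick usage sketch mirroring test_clone_identity_removed above. CloneModule is a hypothetical stand-in for SimpleCloneChannelsLastModule; since the input is already channels-last, the captured clone is an identity clone and the pass erases it whichever clone variant it lowered to:

import torch
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir import to_edge
from torch.export import export


class CloneModule(torch.nn.Module):  # hypothetical toy module
    def forward(self, x):
        return x.clone(memory_format=torch.channels_last)


x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
epm = to_edge(export(CloneModule().eval(), (x,), strict=True))

# The clone is a no-op here, so after the pass the printed graph code
# should contain no clone op at all.
epm = epm.transform([RemoveCloneOpsTransform()])
print(epm.exported_program().graph_module.code)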