
Commit dbac09c

[EXIR] Update RemoveCloneOpsTransform to be dim order aware (#12976)
### Summary

This is PR 3 of 3 implementing a dim-order-aware clone op. This PR updates the clone removal pass to retain layout-changing `aten.clone` and `_clone_dim_order` ops and remove no-op clones, ensuring layout/dim order is preserved through export.

Related PRs:
- PR 1: #12974 - Add `_clone_dim_order` portable kernel
- PR 2: #13735 - Register `_clone_dim_order` op and map `aten.clone`

Fixes #12645

### Test plan

Added tests to verify:
- Clones that change layout are preserved.
- Clones with unchanged layout (identity ops) are removed.

All tests pass via:

    python -m unittest exir.tests.test_memory_format_ops_pass
    python -m unittest backends.transforms.test.test_remove_clone_ops

Co-authored-by: Gasoonjia <[email protected]>
1 parent c3e4ed9 commit dbac09c
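
For orientation, here is a minimal sketch of how the pass is applied to an edge program, mirroring the flow used by the new tests below. The toy module and variable names are illustrative, not part of the PR:

```python
import torch
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir import EdgeCompileConfig, to_edge
from torch.export import export


class CloneToChannelsLast(torch.nn.Module):
    # Illustrative stand-in for the test utility module used in this PR.
    def forward(self, x):
        return x.clone(memory_format=torch.channels_last)


x = torch.randn(3, 4, 5, 6)  # contiguous input, so the clone changes layout
edge = to_edge(
    export(CloneToChannelsLast().eval(), (x,), strict=True),
    compile_config=EdgeCompileConfig(_skip_dim_order=False),
)
# With _skip_dim_order=False, aten.clone is lowered to _clone_dim_order.
# Because this clone changes dim order, the pass must leave it in the graph.
edge = edge.transform([RemoveCloneOpsTransform()])
print(edge.exported_program().graph_module.code)
```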

2 files changed: +99 -2 lines

backends/transforms/remove_clone_ops.py

Lines changed: 30 additions & 2 deletions
```diff
@@ -22,6 +22,7 @@ class RemoveCloneOpsTransform(ExportPass):
 
     clone_ops: Set[torch._ops.OpOverload] = {
         exir_ops.edge.aten.clone.default,
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
     }
 
     def __init__(self) -> None:
@@ -34,12 +35,15 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
             if n.target not in self.clone_ops:
                 continue
 
-            to_be_remove = n
+            if self._is_non_identity_clone(n):
+                continue
+
+            to_be_removed = n
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])
                 if n.args[0].target in _DEQUANT_OPS:
                     dequant_nodes += [n.args[0]]
-            graph_module.graph.erase_node(to_be_remove)
+            graph_module.graph.erase_node(to_be_removed)
 
         eliminate_dq_q(graph_module, dequant_nodes)
 
@@ -48,3 +52,27 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
         return PassResult(graph_module, True)
+
+    def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
+        """Return True if clone has modified memory layout or dim order."""
+
+        # aten.clone: check for memory_format changes
+        if node.target == exir_ops.edge.aten.clone.default:
+            memory_format = node.kwargs.get("memory_format")
+            if memory_format in (None, torch.preserve_format):
+                return False
+            input_meta = node.args[0].meta
+            return "val" in input_meta and not input_meta["val"].is_contiguous(
+                memory_format=memory_format
+            )
+
+        # _clone_dim_order: check for dim_order changes
+        if node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default:
+            input_meta = node.args[0].meta
+            return (
+                "val" in node.meta
+                and "val" in input_meta
+                and node.meta["val"].dim_order() != input_meta["val"].dim_order()
+            )
+
+        return False
```
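
For intuition about the two branches of `_is_non_identity_clone`, the following plain-PyTorch snippet (not part of the diff; requires a PyTorch build with `Tensor.dim_order()`) shows the tensor properties the pass compares:

```python
import torch

x = torch.randn(2, 3, 4, 5)  # contiguous NCHW tensor

# aten.clone branch: a requested memory_format that the input does not
# already satisfy makes the clone layout-changing, so it is preserved.
print(x.is_contiguous(memory_format=torch.channels_last))  # False -> keep clone

# _clone_dim_order branch: the input and output dim orders differ,
# so the clone is non-identity and survives the pass.
y = x.clone(memory_format=torch.channels_last)
print(x.dim_order())  # (0, 1, 2, 3)
print(y.dim_order())  # (0, 2, 3, 1)
```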

backends/transforms/test/test_remove_clone_ops.py

Lines changed: 69 additions & 0 deletions
```diff
@@ -8,13 +8,30 @@
 
 import torch
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
+from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dim_order_utils import is_channel_last_dim_order
+from executorch.exir.tests.test_memory_format_ops_pass_utils import (
+    SimpleCloneChannelsLastModule,
+)
+from torch.export import export
 from torch.fx import GraphModule
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import TestCase
 
 
 class TestRemoveCloneOpsTransform(TestCase):
+    # Clone ops can appear as either aten.clone or _clone_dim_order depending on the _skip_dim_order flag.
+    # _skip_dim_order=True tests aten.clone
+    # _skip_dim_order=False tests _clone_dim_order
+    CLONE_OP_CASES = [
+        (True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
+        (
+            False,
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
+        ),
+    ]
+
     def test_dq_clone_q_linear(self):
         """
         Test RemoveCloneOpsTransform on a graph with d/q -> clone -> q -> linear pattern
@@ -123,6 +140,58 @@ def forward(self, x):
             transformed_gm.code
         )
 
+    def test_clone_non_identity_survives(self):
+        """Verify clone ops that modify memory_format are preserved by RemoveCloneOpsTransform."""
+
+        for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
+            model = SimpleCloneChannelsLastModule()
+            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
+
+            exported = export(model.eval(), (x,), strict=True)
+            before_epm = to_edge(
+                exported,
+                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
+            )
+
+            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
+                updated_epm.exported_program().graph_module.code
+            )
+
+            expected = before_epm.exported_program().module()(x)
+            actual = updated_epm.exported_program().module()(x)
+            assert torch.allclose(actual, expected)
+            assert is_channel_last_dim_order(actual)
+
+    def test_clone_identity_removed(self):
+        """Verify identity clone ops are removed by RemoveCloneOpsTransform."""
+
+        for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
+            model = SimpleCloneChannelsLastModule()
+            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
+
+            exported = export(model.eval(), (x,), strict=True)
+            before_epm = to_edge(
+                exported,
+                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
+            )
+
+            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
+                before_epm.exported_program().graph_module.code
+            )
+
+            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+            FileCheck().check_not(clone_op_str).run(
+                updated_epm.exported_program().graph_module.code
+            )
+
+            expected = before_epm.exported_program().module()(x)
+            actual = updated_epm.exported_program().module()(x)
+            assert torch.allclose(actual, expected)
+            assert is_channel_last_dim_order(actual)
+
 
 if __name__ == "__main__":
     unittest.main()
```
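
The tests depend on `SimpleCloneChannelsLastModule` from the shared memory-format test utilities; its definition is not part of this diff. Judging from how the tests drive it (a contiguous input leaves the clone layout-changing and preserved, while a channels-last input makes it an identity clone that gets removed), it presumably reduces to something like this hypothetical sketch:

```python
import torch


class SimpleCloneChannelsLastModule(torch.nn.Module):
    """Hypothetical sketch; the real module lives in
    exir/tests/test_memory_format_ops_pass_utils.py."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clone(memory_format=torch.channels_last)
```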
