32 changes: 30 additions & 2 deletions backends/transforms/remove_clone_ops.py
@@ -22,6 +22,7 @@ class RemoveCloneOpsTransform(ExportPass):

clone_ops: Set[torch._ops.OpOverload] = {
exir_ops.edge.aten.clone.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
}

def __init__(self) -> None:
@@ -34,12 +35,15 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
if n.target not in self.clone_ops:
continue

to_be_remove = n
if self._is_non_identity_clone(n):
continue

to_be_removed = n
for user_n in list(n.users.keys()):
user_n.replace_input_with(n, n.args[0])
if n.args[0].target in _DEQUANT_OPS:
dequant_nodes += [n.args[0]]
graph_module.graph.erase_node(to_be_remove)
graph_module.graph.erase_node(to_be_removed)

eliminate_dq_q(graph_module, dequant_nodes)

@@ -48,3 +52,27 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module.recompile()
dead_code_elimination_pass(graph_module)
return PassResult(graph_module, True)

def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
"""Return True if clone has modified memory layout or dim order."""

# aten.clone: check for memory_format changes
if node.target == exir_ops.edge.aten.clone.default:
memory_format = node.kwargs.get("memory_format")
if memory_format in (None, torch.preserve_format):
return False
input_meta = node.args[0].meta
return "val" in input_meta and not input_meta["val"].is_contiguous(
memory_format=memory_format
)

# _clone_dim_order: check for dim_order changes
if node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default:
input_meta = node.args[0].meta
return (
"val" in node.meta
and "val" in input_meta
and node.meta["val"].dim_order() != input_meta["val"].dim_order()
)

return False
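
Editor's note: the new _is_non_identity_clone helper boils down to two tensor-level checks, is_contiguous(memory_format=...) for aten.clone and dim_order() for _clone_dim_order. Below is a minimal eager-mode sketch, not part of the PR, of what those checks distinguish; the shapes and variable names are illustrative only and it assumes a PyTorch build where Tensor.dim_order() is available.

# Eager-mode sketch (not part of the PR) of the checks _is_non_identity_clone
# relies on. Assumes Tensor.dim_order() is available in this PyTorch build.
import torch

x = torch.randn(3, 4, 5, 6)  # contiguous NCHW-style input

# Identity clone: memory format is preserved, so the pass may drop it.
same = x.clone(memory_format=torch.preserve_format)
assert same.is_contiguous()

# Non-identity clone: the output becomes channels_last while the input is not,
# which is what the aten.clone branch detects via is_contiguous().
nhwc = x.clone(memory_format=torch.channels_last)
assert not x.is_contiguous(memory_format=torch.channels_last)

# The _clone_dim_order branch compares dim orders instead: a 4D channels_last
# tensor has dim order (0, 2, 3, 1) versus (0, 1, 2, 3) for a contiguous one.
assert x.dim_order() == (0, 1, 2, 3)
assert nhwc.dim_order() == (0, 2, 3, 1)
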
69 changes: 69 additions & 0 deletions backends/transforms/test/test_remove_clone_ops.py
@@ -8,13 +8,30 @@

import torch
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dim_order_utils import is_channel_last_dim_order
from executorch.exir.tests.test_memory_format_ops_pass_utils import (
SimpleCloneChannelsLastModule,
)
from torch.export import export
from torch.fx import GraphModule
from torch.testing import FileCheck
from torch.testing._internal.common_utils import TestCase


class TestRemoveCloneOpsTransform(TestCase):
# Clone ops can appear as either aten.clone or _clone_dim_order depending on the _skip_dim_order flag.
# _skip_dim_order=True tests aten.clone
# _skip_dim_order=False tests _clone_dim_order
CLONE_OP_CASES = [
(True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
(
False,
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
),
]

def test_dq_clone_q_linear(self):
"""
Test RemoveCloneOpsTransform on a graph with d/q -> clone -> q -> linear pattern
@@ -123,6 +140,58 @@ def forward(self, x):
transformed_gm.code
)

def test_clone_non_identity_survives(self):
"""Verify clone ops that modify memory_format are preserved by RemoveCloneOpsTransform."""

for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
model = SimpleCloneChannelsLastModule()
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)

exported = export(model.eval(), (x,), strict=True)
before_epm = to_edge(
exported,
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
)

updated_epm = before_epm.transform([RemoveCloneOpsTransform()])

FileCheck().check_count(clone_op_str, 1, exactly=True).run(
updated_epm.exported_program().graph_module.code
)

expected = before_epm.exported_program().module()(x)
actual = updated_epm.exported_program().module()(x)
assert torch.allclose(actual, expected)
assert is_channel_last_dim_order(actual)

def test_clone_identity_removed(self):
"""Verify identity clone ops are removed by RemoveCloneOpsTransform."""

for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
model = SimpleCloneChannelsLastModule()
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)

exported = export(model.eval(), (x,), strict=True)
before_epm = to_edge(
exported,
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
)

FileCheck().check_count(clone_op_str, 1, exactly=True).run(
before_epm.exported_program().graph_module.code
)

updated_epm = before_epm.transform([RemoveCloneOpsTransform()])

FileCheck().check_not(clone_op_str).run(
updated_epm.exported_program().graph_module.code
)

expected = before_epm.exported_program().module()(x)
actual = updated_epm.exported_program().module()(x)
assert torch.allclose(actual, expected)
assert is_channel_last_dim_order(actual)


if __name__ == "__main__":
unittest.main()
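
Editor's note on the CLONE_OP_CASES pairs above: the same eager-mode clone lowers to a different edge op depending on the _skip_dim_order flag passed to to_edge. The sketch below is editorial, not part of the PR; CloneToChannelsLast is a hypothetical stand-in for SimpleCloneChannelsLastModule and simply shows one way to inspect which clone op ends up in the graph.

# Editorial sketch (not part of the PR): inspect which clone op the edge dialect
# emits under each _skip_dim_order setting. CloneToChannelsLast is a hypothetical
# stand-in for the SimpleCloneChannelsLastModule used in the tests above.
import torch
from executorch.exir import EdgeCompileConfig, to_edge
from torch.export import export


class CloneToChannelsLast(torch.nn.Module):
    def forward(self, x):
        return x.clone(memory_format=torch.channels_last)


x = torch.randn(3, 4, 5, 6)

for skip_dim_order in (True, False):
    ep = export(CloneToChannelsLast().eval(), (x,), strict=True)
    edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order))
    # _skip_dim_order=True  -> ...aten_clone_default shows up in the graph code.
    # _skip_dim_order=False -> ...dim_order_ops__clone_dim_order_default shows up instead.
    print(skip_dim_order, edge.exported_program().graph_module.code)
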