Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions backends/transforms/remove_clone_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class RemoveCloneOpsTransform(ExportPass):

clone_ops: Set[torch._ops.OpOverload] = {
exir_ops.edge.aten.clone.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
}

def __init__(self) -> None:
Expand All @@ -34,12 +35,15 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
if n.target not in self.clone_ops:
continue

to_be_remove = n
if self._is_non_identity_clone(n):
continue

to_be_removed = n
for user_n in list(n.users.keys()):
user_n.replace_input_with(n, n.args[0])
if n.args[0].target in _DEQUANT_OPS:
dequant_nodes += [n.args[0]]
graph_module.graph.erase_node(to_be_remove)
graph_module.graph.erase_node(to_be_removed)

eliminate_dq_q(graph_module, dequant_nodes)

Expand All @@ -48,3 +52,27 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module.recompile()
dead_code_elimination_pass(graph_module)
return PassResult(graph_module, True)

def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
"""Return True if clone has modified memory layout or dim order."""

# aten.clone: check for memory_format changes
if node.target == exir_ops.edge.aten.clone.default:
memory_format = node.kwargs.get("memory_format")
if memory_format in (None, torch.preserve_format):
return False
input_meta = node.args[0].meta
return "val" in input_meta and not input_meta["val"].is_contiguous(
memory_format=memory_format
)

# _clone_dim_order: check for dim_order changes
if node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default:
input_meta = node.args[0].meta
return (
"val" in node.meta
and "val" in input_meta
and node.meta["val"].dim_order() != input_meta["val"].dim_order()
)

return False
69 changes: 69 additions & 0 deletions backends/transforms/test/test_remove_clone_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,30 @@

import torch
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dim_order_utils import is_channel_last_dim_order
from executorch.exir.tests.test_memory_format_ops_pass_utils import (
SimpleCloneChannelsLastModule,
)
from torch.export import export
from torch.fx import GraphModule
from torch.testing import FileCheck
from torch.testing._internal.common_utils import TestCase


class TestRemoveCloneOpsTransform(TestCase):
# Clone ops can appear as either aten.clone or _clone_dim_order depending on the _skip_dim_order flag.
# _skip_dim_order=True tests aten.clone
# _skip_dim_order=False tests _clone_dim_order.
CLONE_OP_CASES = [
(True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
(
False,
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
),
]

def test_dq_clone_q_linear(self):
"""
Test RemoveCloneOpsTransform on a graph with d/q -> clone -> q -> linear pattern
Expand Down Expand Up @@ -123,6 +140,58 @@ def forward(self, x):
transformed_gm.code
)

def test_clone_channels_last_survives(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe should called it test_clone_non_identity_survives? cuz it survives because of mutating memory_format / dim_order, rather than cloning a channels_last tensor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, that makes more sense

"""Verify clone ops that modify memory_format are preserved by RemoveCloneOpsTransform."""

for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
model = SimpleCloneChannelsLastModule()
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)

exported = export(model.eval(), (x,), strict=True)
before_epm = to_edge(
exported,
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
)

updated_epm = before_epm.transform([RemoveCloneOpsTransform()])

FileCheck().check_count(clone_op_str, 1, exactly=True).run(
updated_epm.exported_program().graph_module.code
)

expected = before_epm.exported_program().module()(x)
actual = updated_epm.exported_program().module()(x)
assert torch.allclose(actual, expected)
assert is_channel_last_dim_order(actual)

def test_clone_identity_removed(self):
"""Verify identity clone ops are removed by RemoveCloneOpsTransform."""

for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
model = SimpleCloneChannelsLastModule()
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)

exported = export(model.eval(), (x,), strict=True)
before_epm = to_edge(
exported,
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
)

FileCheck().check_count(clone_op_str, 1, exactly=True).run(
before_epm.exported_program().graph_module.code
)

updated_epm = before_epm.transform([RemoveCloneOpsTransform()])

FileCheck().check_not(clone_op_str).run(
updated_epm.exported_program().graph_module.code
)

expected = before_epm.exported_program().module()(x)
actual = updated_epm.exported_program().module()(x)
assert torch.allclose(actual, expected)
assert is_channel_last_dim_order(actual)


if __name__ == "__main__":
unittest.main()
Loading