37 changes: 35 additions & 2 deletions backends/transforms/remove_clone_ops.py
@@ -22,6 +22,7 @@ class RemoveCloneOpsTransform(ExportPass):

clone_ops: Set[torch._ops.OpOverload] = {
exir_ops.edge.aten.clone.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
}

def __init__(self) -> None:
@@ -34,12 +35,18 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
if n.target not in self.clone_ops:
continue

to_be_remove = n
# Skip removal of clone ops that modify layout/dim order.
if self.aten_clone_is_non_identity(
n
) or self._clone_dim_order_is_non_identity(n):
continue

Contributor Author: UFMT formatter forces this split style.

to_be_removed = n
for user_n in list(n.users.keys()):
user_n.replace_input_with(n, n.args[0])
if n.args[0].target in _DEQUANT_OPS:
dequant_nodes += [n.args[0]]
graph_module.graph.erase_node(to_be_remove)
graph_module.graph.erase_node(to_be_removed)

eliminate_dq_q(graph_module, dequant_nodes)

@@ -48,3 +55,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module.recompile()
dead_code_elimination_pass(graph_module)
return PassResult(graph_module, True)

Contributor: Let's combine the two check functions, aten_clone_is_non_identity and _clone_dim_order_is_non_identity, into a single private function (perhaps called _is_non_identity_clone). In the current scenario the two checks are always used together, and the helper should not be part of the public surface.

def aten_clone_is_non_identity(self, node: torch.fx.Node) -> bool:
"""Return True if aten.clone has modified memory format."""
if node.target != exir_ops.edge.aten.clone.default:
return False

memory_format = node.kwargs.get("memory_format")
if memory_format in (None, torch.preserve_format):
return False

input_meta = node.args[0].meta
return "val" in input_meta and not input_meta["val"].is_contiguous(
memory_format=memory_format
)

def _clone_dim_order_is_non_identity(self, node: torch.fx.Node) -> bool:
"""Return True if _clone_dim_order has modified dim order."""
if node.target != exir_ops.edge.dim_order_ops._clone_dim_order.default:
return False

input_meta = node.args[0].meta
return (
"val" in node.meta
and "val" in input_meta
and node.meta["val"].dim_order() != input_meta["val"].dim_order()
)
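
A minimal sketch of what the merged helper could look like inside RemoveCloneOpsTransform, reusing the two existing checks verbatim. The name _is_non_identity_clone is the reviewer's suggestion, and this sketch is not code from the PR:

    def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
        """Return True if the clone op changes memory format or dim order."""
        # aten.clone: the layout only changes when an explicit memory_format is
        # requested and the input is not already contiguous in that format.
        if node.target == exir_ops.edge.aten.clone.default:
            memory_format = node.kwargs.get("memory_format")
            if memory_format in (None, torch.preserve_format):
                return False
            input_meta = node.args[0].meta
            return "val" in input_meta and not input_meta["val"].is_contiguous(
                memory_format=memory_format
            )

        # _clone_dim_order: compare the output dim order against the input's.
        if node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default:
            input_meta = node.args[0].meta
            return (
                "val" in node.meta
                and "val" in input_meta
                and node.meta["val"].dim_order() != input_meta["val"].dim_order()
            )

        return False

With this helper, the call site in _remove would collapse to a single check, if self._is_non_identity_clone(n): continue, which would also avoid the UFMT-forced line split noted above.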
69 changes: 69 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
@@ -12,6 +12,7 @@
import torch

import torchvision
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -376,6 +377,74 @@ def call_operator(self, op, args, kwargs, meta):
self.assertTrue(is_contiguous_dim_order(actual))
self.assertTrue(is_contiguous_dim_order(expected))

Contributor: Let's move this test to test_remove_clone_ops.py.

def test_op_clone_replacement_channels_last_survives(self):
clone_op_cases = [
# Case testing aten.clone by setting _skip_dim_order to True
(True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
# Case testing _clone_dim_order by setting _skip_dim_order to False
(
False,
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
),
]

for skip_dim_order, clone_op_str in clone_op_cases:
model = SimpleCloneChannelsLastModule()
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)

exported = export(model.eval(), (x,), strict=True)
before_epm = to_edge(
exported,
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
)

updated_epm = before_epm.transform([RemoveCloneOpsTransform()])

FileCheck().check_count(clone_op_str, 1, exactly=True).run(
updated_epm.exported_program().graph_module.code
)

expected = before_epm.exported_program().module()(x)
actual = updated_epm.exported_program().module()(x)
assert torch.allclose(actual, expected)
assert is_channel_last_dim_order(actual)

def test_op_clone_without_transformation_removed(self):
clone_op_cases = [
# Case testing aten.clone by setting _skip_dim_order to True
(True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
# Case testing _clone_dim_order by setting _skip_dim_order to False
(
False,
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
),
]

for skip_dim_order, clone_op_str in clone_op_cases:
model = SimpleCloneChannelsLastModule()
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)

exported = export(model.eval(), (x,), strict=True)
before_epm = to_edge(
exported,
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
)

FileCheck().check_count(clone_op_str, 1, exactly=True).run(
before_epm.exported_program().graph_module.code
)

updated_epm = before_epm.transform([RemoveCloneOpsTransform()])

FileCheck().check_not(clone_op_str).run(
updated_epm.exported_program().graph_module.code
)

expected = before_epm.exported_program().module()(x)
actual = updated_epm.exported_program().module()(x)
assert torch.allclose(actual, expected)
assert is_channel_last_dim_order(actual)

def test_resnet18(self) -> None:
model = torchvision.models.resnet18()
MemoryFormatOpsPassTestUtils.memory_format_test_runner(