Merged
17 changes: 15 additions & 2 deletions backends/transforms/remove_clone_ops.py
```diff
@@ -13,11 +13,24 @@
 
 def remove_clone_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
     """
-    Remove clone op nodes and replace uses with parent node.
+    Remove clone op nodes that have the same dim_order as their input, and replace their uses with the input node.
     """
     clone_op = exir_ops.edge.aten.clone.default
+    clone_dim_order_op = exir_ops.edge.dim_order_ops._clone_dim_order.default
 
     for node in graph.nodes:
-        if node.op == "call_function" and node.target == clone_op:
+        if node.op != "call_function":
+            continue
+
+        # Identify clone_dim_order ops with unchanged memory layout.
+        unchanged_layout_clone = (
+            node.target == clone_dim_order_op
+            and "val" in node.meta
+            and "val" in node.args[0].meta
+            and node.meta["val"].dim_order() == node.args[0].meta["val"].dim_order()
+        )
+
+        if node.target == clone_op or unchanged_layout_clone:
             with graph.inserting_after(node):
                 node.replace_all_uses_with(node.args[0])
```

**Contributor** (on the `unchanged_layout_clone` check): If we are supporting `aten.clone` elimination through this pass, then we should similarly check the `memory_format` arg.

**Contributor Author:** Great point! I added the check for `aten.clone` and updated the tests. I'll refactor/simplify the test cases if needed once we land the AOT PR, since it includes its own tests.
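For reference, the `memory_format` check the reviewer suggests might look something like the sketch below. The helper name `clone_preserves_layout` is hypothetical and not taken from the PR; the sketch relies only on `aten.clone`'s schema, where `memory_format` defaults to `None` and `torch.preserve_format` keeps the input layout:

```python
import torch
import torch.fx

# Hypothetical helper (not the PR's final code): an aten.clone node is
# layout-preserving, and therefore safe to remove, when its memory_format
# kwarg is absent or explicitly torch.preserve_format.
def clone_preserves_layout(node: torch.fx.Node) -> bool:
    memory_format = node.kwargs.get("memory_format")
    return memory_format in (None, torch.preserve_format)
```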
53 changes: 53 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
```diff
@@ -12,6 +12,7 @@
 import torch
 
 import torchvision
+from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -324,6 +325,58 @@ def call_operator(self, op, args, kwargs, meta):
         self.assertTrue(is_contiguous_dim_order(actual))
         self.assertTrue(is_contiguous_dim_order(expected))
 
+    def test_op_clone_replacement_channels_last_survives(self):
+        _clone_dim_order_op_str = (
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+        )
+
+        model = SimpleCloneChannelsLastModule()
+        x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
+
+        exported = export(model.eval(), (x,), strict=True)
+        before_epm = to_edge(
+            exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)
+        )
+
+        updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+        FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run(
+            updated_epm.exported_program().graph_module.code
+        )
+
+        expected = before_epm.exported_program().module()(x)
+        actual = updated_epm.exported_program().module()(x)
+        assert torch.allclose(actual, expected)
+        assert is_channel_last_dim_order(actual)
+
+    def test_op_clone_without_transformation_removed(self):
+        _clone_dim_order_op_str = (
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+        )
+
+        model = SimpleCloneChannelsLastModule()
+        x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
+
+        exported = export(model.eval(), (x,), strict=True)
+        before_epm = to_edge(
+            exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)
+        )
+
+        FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run(
+            before_epm.exported_program().graph_module.code
+        )
+
+        updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+        FileCheck().check_not(_clone_dim_order_op_str).run(
+            updated_epm.exported_program().graph_module.code
+        )
+
+        expected = before_epm.exported_program().module()(x)
+        actual = updated_epm.exported_program().module()(x)
+        assert torch.allclose(actual, expected)
+        assert is_channel_last_dim_order(actual)
+
     def test_resnet18(self) -> None:
         model = torchvision.models.resnet18()
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
```

**Contributor** (on `test_op_clone_replacement_channels_last_survives`): Let's move this test to `test_remove_clone_ops.py`.
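Why the clone survives in the first test but is removed in the second comes down to `Tensor.dim_order()`: the first test feeds a contiguous input to a channels-last model, so the clone changes the dim order and must stay, while the second feeds an already channels-last input, making the clone a layout no-op. A small illustration, assuming a PyTorch build recent enough to expose `Tensor.dim_order()`:

```python
import torch

# A contiguous NCHW tensor keeps the natural dim order...
x = torch.randn(3, 4, 5, 6)
print(x.dim_order())  # (0, 1, 2, 3)

# ...while channels_last lays memory out NHWC.
y = x.to(memory_format=torch.channels_last)
print(y.dim_order())  # (0, 2, 3, 1)

# Cloning a channels_last tensor to channels_last leaves dim_order
# unchanged, so such a clone is removable; a contiguous -> channels_last
# clone changes it, so that clone must survive.
z = y.clone(memory_format=torch.channels_last)
print(z.dim_order() == y.dim_order())  # True
```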