[EXIR] Update RemoveCloneOpsTransform to be dim order aware #12976
```diff
@@ -22,6 +22,7 @@ class RemoveCloneOpsTransform(ExportPass):
     clone_ops: Set[torch._ops.OpOverload] = {
         exir_ops.edge.aten.clone.default,
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
     }

     def __init__(self) -> None:
@@ -34,12 +35,18 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
             if n.target not in self.clone_ops:
                 continue

-            to_be_remove = n
+            # Skip removal of clone ops that modify layout/dim order.
+            if self.aten_clone_is_non_identity(
+                n
+            ) or self._clone_dim_order_is_non_identity(n):
+                continue
+
+            to_be_removed = n
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])
                 if n.args[0].target in _DEQUANT_OPS:
                     dequant_nodes += [n.args[0]]
-            graph_module.graph.erase_node(to_be_remove)
+            graph_module.graph.erase_node(to_be_removed)

         eliminate_dq_q(graph_module, dequant_nodes)

@@ -48,3 +55,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
         return PassResult(graph_module, True)
+
+    def aten_clone_is_non_identity(self, node: torch.fx.Node) -> bool:
+        """Return True if aten.clone has modified memory format."""
+        if node.target != exir_ops.edge.aten.clone.default:
+            return False
+
+        memory_format = node.kwargs.get("memory_format")
+        if memory_format in (None, torch.preserve_format):
+            return False
+
+        input_meta = node.args[0].meta
+        return "val" in input_meta and not input_meta["val"].is_contiguous(
+            memory_format=memory_format
+        )
+
+    def _clone_dim_order_is_non_identity(self, node: torch.fx.Node) -> bool:
+        """Return True if _clone_dim_order has modified dim order."""
+        if node.target != exir_ops.edge.dim_order_ops._clone_dim_order.default:
+            return False
+
+        input_meta = node.args[0].meta
+        return (
+            "val" in node.meta
+            and "val" in input_meta
+            and node.meta["val"].dim_order() != input_meta["val"].dim_order()
+        )
```
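For context, the removal check in this pass boils down to two tensor-level questions: does the clone change the memory format, and does it change the dim order? Below is a minimal standalone sketch of those checks on plain tensors; the helper names are illustrative and not part of the PR.

```python
import torch


def clone_changes_memory_format(t: torch.Tensor, memory_format) -> bool:
    # Mirrors aten_clone_is_non_identity: an explicit memory_format only makes
    # the clone non-removable if the input is not already laid out that way.
    if memory_format in (None, torch.preserve_format):
        return False
    return not t.is_contiguous(memory_format=memory_format)


def clone_changes_dim_order(inp: torch.Tensor, out: torch.Tensor) -> bool:
    # Mirrors _clone_dim_order_is_non_identity: compare the dim orders of the
    # clone's input and output.
    return inp.dim_order() != out.dim_order()


x = torch.randn(3, 4, 5, 6)                     # contiguous NCHW input
y = x.clone(memory_format=torch.channels_last)  # layout-changing clone

print(clone_changes_memory_format(x, torch.channels_last))  # True  -> keep the clone
print(clone_changes_dim_order(x, y))                        # True  -> keep the clone
print(clone_changes_dim_order(x, x.clone()))                # False -> safe to remove
```

In the pass itself the same checks run on FX node metadata (`node.kwargs.get("memory_format")` and the fake tensor stored under `meta["val"]`) rather than on live tensors.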
```diff
@@ -12,6 +12,7 @@
 import torch

 import torchvision
+from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -376,6 +377,74 @@ def call_operator(self, op, args, kwargs, meta):
         self.assertTrue(is_contiguous_dim_order(actual))
         self.assertTrue(is_contiguous_dim_order(expected))

+    def test_op_clone_replacement_channels_last_survives(self):
+        clone_op_cases = [
+            # Case testing aten.clone by setting _skip_dim_order to True
+            (True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
+            # Case testing _clone_dim_order by setting _skip_dim_order to False
+            (
+                False,
+                "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
+            ),
+        ]
+
+        for skip_dim_order, clone_op_str in clone_op_cases:
+            model = SimpleCloneChannelsLastModule()
+            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
+
+            exported = export(model.eval(), (x,), strict=True)
+            before_epm = to_edge(
+                exported,
+                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
+            )
+
+            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
+                updated_epm.exported_program().graph_module.code
+            )
+
+            expected = before_epm.exported_program().module()(x)
+            actual = updated_epm.exported_program().module()(x)
+            assert torch.allclose(actual, expected)
+            assert is_channel_last_dim_order(actual)
+
+    def test_op_clone_without_transformation_removed(self):
+        clone_op_cases = [
+            # Case testing aten.clone by setting _skip_dim_order to True
+            (True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
+            # Case testing _clone_dim_order by setting _skip_dim_order to False
+            (
+                False,
+                "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
+            ),
+        ]
+
+        for skip_dim_order, clone_op_str in clone_op_cases:
+            model = SimpleCloneChannelsLastModule()
+            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
+
+            exported = export(model.eval(), (x,), strict=True)
+            before_epm = to_edge(
+                exported,
+                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
+            )
+
+            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
+                before_epm.exported_program().graph_module.code
+            )
+
+            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+            FileCheck().check_not(clone_op_str).run(
+                updated_epm.exported_program().graph_module.code
+            )
+
+            expected = before_epm.exported_program().module()(x)
+            actual = updated_epm.exported_program().module()(x)
+            assert torch.allclose(actual, expected)
+            assert is_channel_last_dim_order(actual)
+
     def test_resnet18(self) -> None:
         model = torchvision.models.resnet18()
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
```
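Both tests drive `SimpleCloneChannelsLastModule`, which is defined elsewhere in the test utilities and is not shown in this diff. A hypothetical minimal stand-in with the behavior the tests assume (cloning the input into channels_last) might look like:

```python
import torch


class SimpleCloneChannelsLastModule(torch.nn.Module):
    # Hypothetical stand-in; the real module is defined outside this diff.
    # It clones its input into channels_last, so the clone changes layout for
    # a contiguous input (must survive) and is an identity for a channels_last
    # input (safe to remove).
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clone(memory_format=torch.channels_last)
```

That is exactly the split the two tests exercise: the first feeds a contiguous input and expects the clone op to remain in the graph; the second feeds a channels_last input and expects the pass to remove it.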
Review comment: the UFMT formatter forces this split style.