[EXIR] Update RemoveCloneOpsTransform to be dim order aware #12976
Merged: Gasoonjia merged 13 commits into pytorch:main from keyprocedure:add-dim-order-clone-transform on Sep 9, 2025.
Changes shown below are from 2 of the 13 commits.

Commits:
- f2f2932 Update clone removal transform to be dim order aware; add tests (keyprocedure)
- ad74bdf Remove explicit MemoryFormatOpsPass transform from clone_dim_order tests (keyprocedure)
- ffd1549 Add aten.clone memory_format check in RemoveCloneOpsTransform (keyprocedure)
- e14a700 Merge branch 'main' into add-dim-order-clone-transform (keyprocedure)
- b8485bc Refactor clone identity check into _is_non_identity_clone (keyprocedure)
- e133898 Move clone tests from test_memory_format_ops_pass to test_remove_clon… (keyprocedure)
- 17f2e6c Change clone test name to test_clone_non_identity_survives (keyprocedure)
- 0cbb5e0 Merge branch 'main' into add-dim-order-clone-transform (Gasoonjia)
- 4b68e11 Add test_remove_clone_ops to pytest.ini config (keyprocedure)
- 21a516a Merge branch 'main' into add-dim-order-clone-transform (Gasoonjia)
- f623103 Merge branch 'main' into add-dim-order-clone-transform (keyprocedure)
- ffa3101 Format pytest.ini (keyprocedure)
- 15ff154 Merge branch 'main' into add-dim-order-clone-transform (Gasoonjia)
@@ -12,6 +12,7 @@
 import torch

 import torchvision
+from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload

@@ -324,6 +325,58 @@ def call_operator(self, op, args, kwargs, meta):
         self.assertTrue(is_contiguous_dim_order(actual))
         self.assertTrue(is_contiguous_dim_order(expected))

+    def test_op_clone_replacement_channels_last_survives(self):
+        _clone_dim_order_op_str = (
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+        )
+
+        model = SimpleCloneChannelsLastModule()
+        x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
+
+        exported = export(model.eval(), (x,), strict=True)
+        before_epm = to_edge(
+            exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)
+        )
+
+        updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+        FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run(
+            updated_epm.exported_program().graph_module.code
+        )
+
+        expected = before_epm.exported_program().module()(x)
+        actual = updated_epm.exported_program().module()(x)
+        assert torch.allclose(actual, expected)
+        assert is_channel_last_dim_order(actual)
+
+    def test_op_clone_without_transformation_removed(self):
+        _clone_dim_order_op_str = (
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+        )
+
+        model = SimpleCloneChannelsLastModule()
+        x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
+
+        exported = export(model.eval(), (x,), strict=True)
+        before_epm = to_edge(
+            exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)
+        )
+
+        FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run(
+            before_epm.exported_program().graph_module.code
+        )
+
+        updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+        FileCheck().check_not(_clone_dim_order_op_str).run(
+            updated_epm.exported_program().graph_module.code
+        )
+
+        expected = before_epm.exported_program().module()(x)
+        actual = updated_epm.exported_program().module()(x)
+        assert torch.allclose(actual, expected)
+        assert is_channel_last_dim_order(actual)
+
     def test_resnet18(self) -> None:
         model = torchvision.models.resnet18()
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
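The distinction these tests exercise is whether a clone actually changes the tensor's dim order: a clone that converts layouts must survive the pass, while a true no-op clone may be removed. Commit b8485bc refactors this decision into `_is_non_identity_clone`. The following is a minimal, self-contained sketch of that idea only; `CloneNode` is a hypothetical stand-in, since the real pass inspects `torch.fx` nodes inside `RemoveCloneOpsTransform`:

```python
from dataclasses import dataclass
from typing import Optional, Tuple


# Hypothetical stand-in for a graph node; the real pass reads dim order
# metadata off torch.fx call nodes for _clone_dim_order.
@dataclass
class CloneNode:
    input_dim_order: Tuple[int, ...]
    target_dim_order: Optional[Tuple[int, ...]] = None


def is_non_identity_clone(node: CloneNode) -> bool:
    # A clone that requests a dim order different from its input's
    # relayouts data, so it must not be removed by the pass.
    return (
        node.target_dim_order is not None
        and node.target_dim_order != node.input_dim_order
    )


# Contiguous NCHW -> channels-last clone: non-identity, must survive.
assert is_non_identity_clone(CloneNode((0, 1, 2, 3), (0, 2, 3, 1)))
# Clone that keeps the same dim order: identity, safe to remove.
assert not is_non_identity_clone(CloneNode((0, 1, 2, 3), (0, 1, 2, 3)))
```

This mirrors the two tests above: the channels-last conversion clone is counted as still present after the transform, while the already-channels-last clone is checked to be gone.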
Review comment: If we are supporting `aten.clone` elimination through this pass, then we should similarly check the `memory_format` arg.
Reply: Great point! I added the check for `aten.clone` and updated the tests. I'll refactor/simplify the test cases if needed once we land the AOT PR, since it includes its own tests.
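The `memory_format` check discussed in this thread can be sketched without ExecuTorch. The string constants below are hypothetical stand-ins for `torch.preserve_format`, `torch.contiguous_format`, and `torch.channels_last`; the rule is deliberately conservative, treating only an unset or preserve-format clone as removable:

```python
# Hypothetical tokens standing in for torch memory formats.
PRESERVE_FORMAT = "preserve_format"
CONTIGUOUS_FORMAT = "contiguous_format"
CHANNELS_LAST = "channels_last"


def clone_is_identity(memory_format) -> bool:
    # A clone with no memory_format, or with preserve_format, copies data
    # without relayout, so removing it cannot change the output layout.
    # Any explicit format is treated as potentially layout-changing, since
    # the pass may not statically know the input's current layout.
    return memory_format in (None, PRESERVE_FORMAT)


assert clone_is_identity(None)
assert clone_is_identity(PRESERVE_FORMAT)
assert not clone_is_identity(CHANNELS_LAST)
```

Commit ffd1549 ("Add aten.clone memory_format check in RemoveCloneOpsTransform") is where the actual check landed; the sketch above only illustrates the decision rule under the stated assumptions.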