|
8 | 8 |
|
9 | 9 | import torch
|
10 | 10 | from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
|
| 11 | +from executorch.exir import EdgeCompileConfig, to_edge |
11 | 12 | from executorch.exir.dialects._ops import ops as exir_ops
|
| 13 | +from executorch.exir.dim_order_utils import is_channel_last_dim_order |
| 14 | +from executorch.exir.tests.test_memory_format_ops_pass_utils import ( |
| 15 | + SimpleCloneChannelsLastModule, |
| 16 | +) |
| 17 | +from torch.export import export |
12 | 18 | from torch.fx import GraphModule
|
13 | 19 | from torch.testing import FileCheck
|
14 | 20 | from torch.testing._internal.common_utils import TestCase
|
15 | 21 |
|
16 | 22 |
|
17 | 23 | class TestRemoveCloneOpsTransform(TestCase):
|
| 24 | + # Clone ops can appear as either aten.clone or _clone_dim_order depending on the _skip_dim_order flag. |
| 25 | + # _skip_dim_order=True tests aten.clone |
| 26 | + # _skip_dim_order=False tests _clone_dim_order |
| 27 | + CLONE_OP_CASES = [ |
| 28 | + (True, "executorch_exir_dialects_edge__ops_aten_clone_default"), |
| 29 | + ( |
| 30 | + False, |
| 31 | + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default", |
| 32 | + ), |
| 33 | + ] |
| 34 | + |
18 | 35 | def test_dq_clone_q_linear(self):
|
19 | 36 | """
|
20 | 37 | Test RemoveCloneOpsTransform on a graph with d/q -> clone -> q -> linear pattern
|
@@ -123,6 +140,58 @@ def forward(self, x):
|
123 | 140 | transformed_gm.code
|
124 | 141 | )
|
125 | 142 |
|
| 143 | + def test_clone_non_identity_survives(self): |
| 144 | + """Verify clone ops that modify memory_format are preserved by RemoveCloneOpsTransform.""" |
| 145 | + |
| 146 | + for skip_dim_order, clone_op_str in self.CLONE_OP_CASES: |
| 147 | + model = SimpleCloneChannelsLastModule() |
| 148 | + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) |
| 149 | + |
| 150 | + exported = export(model.eval(), (x,), strict=True) |
| 151 | + before_epm = to_edge( |
| 152 | + exported, |
| 153 | + compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order), |
| 154 | + ) |
| 155 | + |
| 156 | + updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) |
| 157 | + |
| 158 | + FileCheck().check_count(clone_op_str, 1, exactly=True).run( |
| 159 | + updated_epm.exported_program().graph_module.code |
| 160 | + ) |
| 161 | + |
| 162 | + expected = before_epm.exported_program().module()(x) |
| 163 | + actual = updated_epm.exported_program().module()(x) |
| 164 | + assert torch.allclose(actual, expected) |
| 165 | + assert is_channel_last_dim_order(actual) |
| 166 | + |
| 167 | + def test_clone_identity_removed(self): |
| 168 | + """Verify identity clone ops are removed by RemoveCloneOpsTransform.""" |
| 169 | + |
| 170 | + for skip_dim_order, clone_op_str in self.CLONE_OP_CASES: |
| 171 | + model = SimpleCloneChannelsLastModule() |
| 172 | + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last) |
| 173 | + |
| 174 | + exported = export(model.eval(), (x,), strict=True) |
| 175 | + before_epm = to_edge( |
| 176 | + exported, |
| 177 | + compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order), |
| 178 | + ) |
| 179 | + |
| 180 | + FileCheck().check_count(clone_op_str, 1, exactly=True).run( |
| 181 | + before_epm.exported_program().graph_module.code |
| 182 | + ) |
| 183 | + |
| 184 | + updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) |
| 185 | + |
| 186 | + FileCheck().check_not(clone_op_str).run( |
| 187 | + updated_epm.exported_program().graph_module.code |
| 188 | + ) |
| 189 | + |
| 190 | + expected = before_epm.exported_program().module()(x) |
| 191 | + actual = updated_epm.exported_program().module()(x) |
| 192 | + assert torch.allclose(actual, expected) |
| 193 | + assert is_channel_last_dim_order(actual) |
| 194 | + |
126 | 195 |
|
127 | 196 | if __name__ == "__main__":
|
128 | 197 | unittest.main()
|
0 commit comments