Skip to content

Commit e133898

Browse files
committed
Move clone tests from test_memory_format_ops_pass to test_remove_clone_ops
1 parent b8485bc commit e133898

File tree

2 files changed

+69
-69
lines changed

2 files changed

+69
-69
lines changed

backends/transforms/test/test_remove_clone_ops.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,30 @@
88

99
import torch
1010
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
11+
from executorch.exir import EdgeCompileConfig, to_edge
1112
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.dim_order_utils import is_channel_last_dim_order
14+
from executorch.exir.tests.test_memory_format_ops_pass_utils import (
15+
SimpleCloneChannelsLastModule,
16+
)
17+
from torch.export import export
1218
from torch.fx import GraphModule
1319
from torch.testing import FileCheck
1420
from torch.testing._internal.common_utils import TestCase
1521

1622

1723
class TestRemoveCloneOpsTransform(TestCase):
24+
# Clone ops can appear as either aten.clone or _clone_dim_order depending on the _skip_dim_order flag.
25+
# _skip_dim_order=True tests aten.clone
26+
# _skip_dim_order=False tests _clone_dim_order.
27+
CLONE_OP_CASES = [
28+
(True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
29+
(
30+
False,
31+
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
32+
),
33+
]
34+
1835
def test_dq_clone_q_linear(self):
1936
"""
2037
Test RemoveCloneOpsTransform on a graph with d/q -> clone -> q -> linear pattern
@@ -123,6 +140,58 @@ def forward(self, x):
123140
transformed_gm.code
124141
)
125142

143+
def test_clone_channels_last_survives(self):
144+
"""Verify clone ops that modify memory_format are preserved by RemoveCloneOpsTransform."""
145+
146+
for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
147+
model = SimpleCloneChannelsLastModule()
148+
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
149+
150+
exported = export(model.eval(), (x,), strict=True)
151+
before_epm = to_edge(
152+
exported,
153+
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
154+
)
155+
156+
updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
157+
158+
FileCheck().check_count(clone_op_str, 1, exactly=True).run(
159+
updated_epm.exported_program().graph_module.code
160+
)
161+
162+
expected = before_epm.exported_program().module()(x)
163+
actual = updated_epm.exported_program().module()(x)
164+
assert torch.allclose(actual, expected)
165+
assert is_channel_last_dim_order(actual)
166+
167+
def test_clone_identity_removed(self):
168+
"""Verify identity clone ops are removed by RemoveCloneOpsTransform."""
169+
170+
for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
171+
model = SimpleCloneChannelsLastModule()
172+
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
173+
174+
exported = export(model.eval(), (x,), strict=True)
175+
before_epm = to_edge(
176+
exported,
177+
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
178+
)
179+
180+
FileCheck().check_count(clone_op_str, 1, exactly=True).run(
181+
before_epm.exported_program().graph_module.code
182+
)
183+
184+
updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
185+
186+
FileCheck().check_not(clone_op_str).run(
187+
updated_epm.exported_program().graph_module.code
188+
)
189+
190+
expected = before_epm.exported_program().module()(x)
191+
actual = updated_epm.exported_program().module()(x)
192+
assert torch.allclose(actual, expected)
193+
assert is_channel_last_dim_order(actual)
194+
126195

127196
if __name__ == "__main__":
128197
unittest.main()

exir/tests/test_memory_format_ops_pass.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import torch
1313

1414
import torchvision
15-
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
1615
from executorch.exir import EdgeCompileConfig, to_edge
1716
from executorch.exir.dialects._ops import ops as exir_ops
1817
from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -377,74 +376,6 @@ def call_operator(self, op, args, kwargs, meta):
377376
self.assertTrue(is_contiguous_dim_order(actual))
378377
self.assertTrue(is_contiguous_dim_order(expected))
379378

380-
def test_op_clone_replacement_channels_last_survives(self):
381-
clone_op_cases = [
382-
# Case testing aten.clone by setting _skip_dim_order to True
383-
(True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
384-
# Case testing _clone_dim_order by setting _skip_dim_order to False
385-
(
386-
False,
387-
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
388-
),
389-
]
390-
391-
for skip_dim_order, clone_op_str in clone_op_cases:
392-
model = SimpleCloneChannelsLastModule()
393-
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
394-
395-
exported = export(model.eval(), (x,), strict=True)
396-
before_epm = to_edge(
397-
exported,
398-
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
399-
)
400-
401-
updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
402-
403-
FileCheck().check_count(clone_op_str, 1, exactly=True).run(
404-
updated_epm.exported_program().graph_module.code
405-
)
406-
407-
expected = before_epm.exported_program().module()(x)
408-
actual = updated_epm.exported_program().module()(x)
409-
assert torch.allclose(actual, expected)
410-
assert is_channel_last_dim_order(actual)
411-
412-
def test_op_clone_without_transformation_removed(self):
413-
clone_op_cases = [
414-
# Case testing aten.clone by setting _skip_dim_order to True
415-
(True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
416-
# Case testing _clone_dim_order by setting _skip_dim_order to False
417-
(
418-
False,
419-
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
420-
),
421-
]
422-
423-
for skip_dim_order, clone_op_str in clone_op_cases:
424-
model = SimpleCloneChannelsLastModule()
425-
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
426-
427-
exported = export(model.eval(), (x,), strict=True)
428-
before_epm = to_edge(
429-
exported,
430-
compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
431-
)
432-
433-
FileCheck().check_count(clone_op_str, 1, exactly=True).run(
434-
before_epm.exported_program().graph_module.code
435-
)
436-
437-
updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
438-
439-
FileCheck().check_not(clone_op_str).run(
440-
updated_epm.exported_program().graph_module.code
441-
)
442-
443-
expected = before_epm.exported_program().module()(x)
444-
actual = updated_epm.exported_program().module()(x)
445-
assert torch.allclose(actual, expected)
446-
assert is_channel_last_dim_order(actual)
447-
448379
def test_resnet18(self) -> None:
449380
model = torchvision.models.resnet18()
450381
MemoryFormatOpsPassTestUtils.memory_format_test_runner(

0 commit comments

Comments
 (0)