Commit cff39c9

Add graph-level and end-to-end tests for _clone_dim_order op
1 parent 83d8c75 commit cff39c9

File tree: 2 files changed, +84 −14 lines
  exir/tests/test_memory_format_ops_pass.py
  exir/tests/test_memory_format_ops_pass_utils.py

2 files changed

+84
-14
lines changed

exir/tests/test_memory_format_ops_pass.py

Lines changed: 74 additions & 14 deletions
@@ -27,8 +27,10 @@
     AmbiguousDimOrderError,
     MemoryFormatOpsPassTestUtils,
     MemoryFormatTestSet,
+    PropagateToCloneChannelsLastModule,
     PropagateToCopyChannalsLastModule,
     SimpleCloneChannelsLastModule,
+    SimpleCloneContiguousModule,
     SimpleEmptyChannelLastModule,
     SimpleEmptyContiguoustModule,
     SimpleToCopyChannelsLastModule,
@@ -92,6 +94,36 @@ def test_op_empty_replacement_contiguous(self) -> None:
             ),
         )
 
+    def test_op_clone_replacement_contiguous(self) -> None:
+        model = SimpleCloneContiguousModule()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last),
+                ),
+                target_memory_format=torch.contiguous_format,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
+    def test_op_clone_replacement_channels_last(self) -> None:
+        model = SimpleCloneChannelsLastModule()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format),
+                ),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
     def test_op_dim_order_update(self) -> None:
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
             self,
@@ -129,6 +161,25 @@ def test_op_dim_order_propagation(self) -> None:
             check_unambiguous_dim_order=True,
         )
 
+    def test_op_clone_dim_order_propagation(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=PropagateToCloneChannelsLastModule().eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.rand_like(
+                        torch.zeros([2, 2, 2, 2]),
+                        dtype=torch.float32,
+                        memory_format=torch.contiguous_format,
+                    ),
+                ),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+            check_unambiguous_dim_order=True,
+        )
+
     def test_op_dim_order_propagation_ambiguous(self) -> None:
         try:
             MemoryFormatOpsPassTestUtils.memory_format_test_runner(
@@ -154,6 +205,29 @@ def test_op_dim_order_propagation_ambiguous(self) -> None:
         except AmbiguousDimOrderError:
             pass  # Expected error
 
+    def test_op_clone_dim_order_graph_replacement(self):
+        model = SimpleCloneChannelsLastModule()
+        x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
+        _clone_dim_order_op_str = (
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+        )
+
+        exported = export(model.eval(), (x,), strict=True)
+        epm = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False))
+
+        # Verify one _clone_dim_order op exists and aten.clone.default nodes have been removed.
+        (
+            FileCheck()
+            .check_not(
+                "aten.clone.default"
+            )  # Check before first _clone_dim_order_op_str match.
+            .check_count(_clone_dim_order_op_str, 1, exactly=True)
+            .check_not(
+                "aten.clone.default"
+            )  # Check after _clone_dim_order_op_str match.
+            .run(epm.exported_program().graph_module.code)
+        )
+
     # Only test dim order replacement result in lean mode test.
     # This test is irrelevant with operator mode.
     def test_dim_order_replacement(self) -> None:
@@ -390,17 +464,3 @@ def test_mobilenet_v3_xnnpack(self) -> None:
                 rtol=1e-3,
             ),
         )
-
-    def test_op_clone_dim_order_registration(self):
-        model = SimpleCloneChannelsLastModule()
-        x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
-        clone_dim_order_op_str = (
-            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
-        )
-
-        exported = export(model.eval(), (x,), strict=True)
-        epm = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False))
-
-        FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run(
-            epm.exported_program().graph_module.code
-        )
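For context, here is a minimal standalone sketch of the rewrite that the new graph-level test verifies: with dim order enabled, aten.clone.default should become the _clone_dim_order edge op. This is a sketch only, assuming an ExecuTorch development install and the usual executorch.exir / torch.export import paths; the CloneChannelsLast module is an illustrative stand-in, not code from this commit.

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from torch.export import export


class CloneChannelsLast(torch.nn.Module):
    # Illustrative stand-in for SimpleCloneChannelsLastModule.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clone(memory_format=torch.channels_last)


x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
edge = to_edge(
    export(CloneChannelsLast().eval(), (x,), strict=True),
    compile_config=EdgeCompileConfig(_skip_dim_order=False),
)
# With _skip_dim_order=False, aten.clone.default is expected to be rewritten to
# _clone_dim_order in the edge graph; printing the graph code makes visible the
# replacement that the FileCheck assertions above look for.
print(edge.exported_program().graph_module.code)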

exir/tests/test_memory_format_ops_pass_utils.py

Lines changed: 10 additions & 0 deletions
@@ -122,6 +122,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return t1 * t2
 
 
+class PropagateToCloneChannelsLastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        t1 = x.clone(memory_format=torch.channels_last)
+        t2 = t1 + t1
+        return t1 * t2
+
+
 class AmbiguousDimOrderError(RuntimeError):
     pass
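As a quick eager-mode illustration of why this helper module suits the propagation test, the sketch below mirrors the class added above (trivial __init__ omitted) and runs it outside the ExecuTorch pipeline; the "expected" output is an assumption about PyTorch's eager memory-format propagation, not something asserted by this commit.

import torch


class PropagateToCloneChannelsLastModule(torch.nn.Module):
    # Mirrors the module added in the diff above, for a self-contained sketch.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t1 = x.clone(memory_format=torch.channels_last)
        t2 = t1 + t1
        return t1 * t2


x = torch.rand(2, 2, 2, 2)  # contiguous input, as in the test's sample_input
y = PropagateToCloneChannelsLastModule().eval()(x)
# The channels_last layout introduced by the clone propagates through the add
# and mul; the new test asserts the analogous dim-order propagation on the
# exported graph via check_unambiguous_dim_order=True.
print(y.is_contiguous(memory_format=torch.channels_last))  # expected: True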