Skip to content

Commit ad74bdf

Browse files
committed
Remove explicit MemoryFormatOpsPass transform from clone_dim_order tests
1 parent f2f2932 commit ad74bdf

File tree

1 file changed

+52
-55
lines changed

1 file changed

+52
-55
lines changed

exir/tests/test_memory_format_ops_pass.py

Lines changed: 52 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,12 @@
2323
is_contiguous_dim_order,
2424
)
2525
from executorch.exir.pass_base import ExportPass, ProxyValue
26-
from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
2726

2827
from executorch.exir.tests.test_memory_format_ops_pass_utils import (
2928
AmbiguousDimOrderError,
3029
MemoryFormatOpsPassTestUtils,
3130
MemoryFormatTestSet,
3231
PropagateToCopyChannalsLastModule,
33-
SimpleCloneChannelsLastModule,
3432
SimpleEmptyChannelLastModule,
3533
SimpleEmptyContiguoustModule,
3634
SimpleToCopyChannelsLastModule,
@@ -327,6 +325,58 @@ def call_operator(self, op, args, kwargs, meta):
327325
self.assertTrue(is_contiguous_dim_order(actual))
328326
self.assertTrue(is_contiguous_dim_order(expected))
329327

328+
def test_op_clone_replacement_channels_last_survives(self):
329+
_clone_dim_order_op_str = (
330+
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
331+
)
332+
333+
model = SimpleCloneChannelsLastModule()
334+
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
335+
336+
exported = export(model.eval(), (x,), strict=True)
337+
before_epm = to_edge(
338+
exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)
339+
)
340+
341+
updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
342+
343+
FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run(
344+
updated_epm.exported_program().graph_module.code
345+
)
346+
347+
expected = before_epm.exported_program().module()(x)
348+
actual = updated_epm.exported_program().module()(x)
349+
assert torch.allclose(actual, expected)
350+
assert is_channel_last_dim_order(actual)
351+
352+
def test_op_clone_without_transformation_removed(self):
353+
_clone_dim_order_op_str = (
354+
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
355+
)
356+
357+
model = SimpleCloneChannelsLastModule()
358+
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
359+
360+
exported = export(model.eval(), (x,), strict=True)
361+
before_epm = to_edge(
362+
exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)
363+
)
364+
365+
FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run(
366+
before_epm.exported_program().graph_module.code
367+
)
368+
369+
updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
370+
371+
FileCheck().check_not(_clone_dim_order_op_str).run(
372+
updated_epm.exported_program().graph_module.code
373+
)
374+
375+
expected = before_epm.exported_program().module()(x)
376+
actual = updated_epm.exported_program().module()(x)
377+
assert torch.allclose(actual, expected)
378+
assert is_channel_last_dim_order(actual)
379+
330380
def test_resnet18(self) -> None:
331381
model = torchvision.models.resnet18()
332382
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
@@ -392,56 +442,3 @@ def test_mobilenet_v3_xnnpack(self) -> None:
392442
rtol=1e-3,
393443
),
394444
)
395-
396-
def test_op_clone_replacement_channels_last_survives(self):
397-
clone_dim_order_op_str = (
398-
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
399-
)
400-
401-
model = SimpleCloneChannelsLastModule()
402-
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
403-
404-
exported = export(model.eval(), (x,), strict=True)
405-
before_epm = to_edge(
406-
exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)
407-
)
408-
409-
updated_epm = before_epm.transform([MemoryFormatOpsPass()])
410-
updated_epm = updated_epm.transform([RemoveCloneOpsTransform()])
411-
412-
FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run(
413-
updated_epm.exported_program().graph_module.code
414-
)
415-
416-
expected = before_epm.exported_program().module()(x)
417-
actual = updated_epm.exported_program().module()(x)
418-
assert torch.allclose(actual, expected)
419-
assert is_channel_last_dim_order(actual)
420-
421-
def test_op_clone_without_transformation_removed(self):
422-
clone_dim_order_op_str = (
423-
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
424-
)
425-
426-
model = SimpleCloneChannelsLastModule()
427-
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
428-
429-
exported = export(model.eval(), (x,), strict=True)
430-
before_epm = to_edge(
431-
exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)
432-
)
433-
434-
updated_epm = before_epm.transform([MemoryFormatOpsPass()])
435-
FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run(
436-
updated_epm.exported_program().graph_module.code
437-
)
438-
439-
updated_epm = updated_epm.transform([RemoveCloneOpsTransform()])
440-
FileCheck().check_not(clone_dim_order_op_str).run(
441-
updated_epm.exported_program().graph_module.code
442-
)
443-
444-
expected = before_epm.exported_program().module()(x)
445-
actual = updated_epm.exported_program().module()(x)
446-
assert torch.allclose(actual, expected)
447-
assert is_channel_last_dim_order(actual)

0 commit comments

Comments
 (0)