|
23 | 23 | is_contiguous_dim_order, |
24 | 24 | ) |
25 | 25 | from executorch.exir.pass_base import ExportPass, ProxyValue |
26 | | -from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass |
27 | 26 |
|
28 | 27 | from executorch.exir.tests.test_memory_format_ops_pass_utils import ( |
29 | 28 | AmbiguousDimOrderError, |
30 | 29 | MemoryFormatOpsPassTestUtils, |
31 | 30 | MemoryFormatTestSet, |
32 | 31 | PropagateToCopyChannalsLastModule, |
33 | | - SimpleCloneChannelsLastModule, |
34 | 32 | SimpleEmptyChannelLastModule, |
35 | 33 | SimpleEmptyContiguoustModule, |
36 | 34 | SimpleToCopyChannelsLastModule, |
@@ -327,6 +325,58 @@ def call_operator(self, op, args, kwargs, meta): |
327 | 325 | self.assertTrue(is_contiguous_dim_order(actual)) |
328 | 326 | self.assertTrue(is_contiguous_dim_order(expected)) |
329 | 327 |
|
| 328 | + def test_op_clone_replacement_channels_last_survives(self): |
| 329 | + _clone_dim_order_op_str = ( |
| 330 | + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" |
| 331 | + ) |
| 332 | + |
| 333 | + model = SimpleCloneChannelsLastModule() |
| 334 | + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) |
| 335 | + |
| 336 | + exported = export(model.eval(), (x,), strict=True) |
| 337 | + before_epm = to_edge( |
| 338 | + exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) |
| 339 | + ) |
| 340 | + |
| 341 | + updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) |
| 342 | + |
| 343 | + FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run( |
| 344 | + updated_epm.exported_program().graph_module.code |
| 345 | + ) |
| 346 | + |
| 347 | + expected = before_epm.exported_program().module()(x) |
| 348 | + actual = updated_epm.exported_program().module()(x) |
| 349 | + assert torch.allclose(actual, expected) |
| 350 | + assert is_channel_last_dim_order(actual) |
| 351 | + |
| 352 | + def test_op_clone_without_transformation_removed(self): |
| 353 | + _clone_dim_order_op_str = ( |
| 354 | + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" |
| 355 | + ) |
| 356 | + |
| 357 | + model = SimpleCloneChannelsLastModule() |
| 358 | + x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last) |
| 359 | + |
| 360 | + exported = export(model.eval(), (x,), strict=True) |
| 361 | + before_epm = to_edge( |
| 362 | + exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) |
| 363 | + ) |
| 364 | + |
| 365 | + FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run( |
| 366 | + before_epm.exported_program().graph_module.code |
| 367 | + ) |
| 368 | + |
| 369 | + updated_epm = before_epm.transform([RemoveCloneOpsTransform()]) |
| 370 | + |
| 371 | + FileCheck().check_not(_clone_dim_order_op_str).run( |
| 372 | + updated_epm.exported_program().graph_module.code |
| 373 | + ) |
| 374 | + |
| 375 | + expected = before_epm.exported_program().module()(x) |
| 376 | + actual = updated_epm.exported_program().module()(x) |
| 377 | + assert torch.allclose(actual, expected) |
| 378 | + assert is_channel_last_dim_order(actual) |
| 379 | + |
330 | 380 | def test_resnet18(self) -> None: |
331 | 381 | model = torchvision.models.resnet18() |
332 | 382 | MemoryFormatOpsPassTestUtils.memory_format_test_runner( |
@@ -392,56 +442,3 @@ def test_mobilenet_v3_xnnpack(self) -> None: |
392 | 442 | rtol=1e-3, |
393 | 443 | ), |
394 | 444 | ) |
395 | | - |
396 | | - def test_op_clone_replacement_channels_last_survives(self): |
397 | | - clone_dim_order_op_str = ( |
398 | | - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" |
399 | | - ) |
400 | | - |
401 | | - model = SimpleCloneChannelsLastModule() |
402 | | - x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format) |
403 | | - |
404 | | - exported = export(model.eval(), (x,), strict=True) |
405 | | - before_epm = to_edge( |
406 | | - exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) |
407 | | - ) |
408 | | - |
409 | | - updated_epm = before_epm.transform([MemoryFormatOpsPass()]) |
410 | | - updated_epm = updated_epm.transform([RemoveCloneOpsTransform()]) |
411 | | - |
412 | | - FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run( |
413 | | - updated_epm.exported_program().graph_module.code |
414 | | - ) |
415 | | - |
416 | | - expected = before_epm.exported_program().module()(x) |
417 | | - actual = updated_epm.exported_program().module()(x) |
418 | | - assert torch.allclose(actual, expected) |
419 | | - assert is_channel_last_dim_order(actual) |
420 | | - |
421 | | - def test_op_clone_without_transformation_removed(self): |
422 | | - clone_dim_order_op_str = ( |
423 | | - "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" |
424 | | - ) |
425 | | - |
426 | | - model = SimpleCloneChannelsLastModule() |
427 | | - x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last) |
428 | | - |
429 | | - exported = export(model.eval(), (x,), strict=True) |
430 | | - before_epm = to_edge( |
431 | | - exported, compile_config=EdgeCompileConfig(_skip_dim_order=False) |
432 | | - ) |
433 | | - |
434 | | - updated_epm = before_epm.transform([MemoryFormatOpsPass()]) |
435 | | - FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run( |
436 | | - updated_epm.exported_program().graph_module.code |
437 | | - ) |
438 | | - |
439 | | - updated_epm = updated_epm.transform([RemoveCloneOpsTransform()]) |
440 | | - FileCheck().check_not(clone_dim_order_op_str).run( |
441 | | - updated_epm.exported_program().graph_module.code |
442 | | - ) |
443 | | - |
444 | | - expected = before_epm.exported_program().module()(x) |
445 | | - actual = updated_epm.exported_program().module()(x) |
446 | | - assert torch.allclose(actual, expected) |
447 | | - assert is_channel_last_dim_order(actual) |
0 commit comments