Skip to content

Commit 5e492b2

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
move memory format pass into to_edge (pytorch#1891)
Summary: Pull Request resolved: pytorch#1891 This diff enable memory foramt pass by moving it into to_edge. Also introduced `_skip_dim_order` in edge compile config for gradually enable the pass. Currently we set its default as True, only enable it in `test_memory_format_ops_pass` for evaluation. Will graduatlly enable it in our system, and finally remove it from EdgeCompileConfig Reviewed By: larryliu0820 Differential Revision: D53567636 fbshipit-source-id: 969dea799d785d825e796386afcdfa7144a47ba9
1 parent bf7b701 commit 5e492b2

File tree

4 files changed

+40
-28
lines changed

4 files changed

+40
-28
lines changed

exir/capture/_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class EdgeCompileConfig:
3535
# TODO(larryliu): remove this
3636
_use_edge_ops: bool = True
3737
_skip_type_promotion: bool = False
38+
# TODO(gasoonjia): set it as False by default, and remove it in the long term
39+
_skip_dim_order: bool = True
3840

3941

4042
@compatibility(is_backward_compatible=False)

exir/passes/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
464464

465465
# Passes to convert a graph module from ATen to Edge IR
466466

467-
pre_op_replace_passes = PassManager(
467+
base_pre_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = PassManager(
468468
passes=[
469469
# ReplaceSymSizeOpPass need to be run before other passes which inherits
470470
# from ExportPass. ExportPass can not handle OpOverloadPacket in its
@@ -479,7 +479,9 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
479479
]
480480
).passes
481481

482-
post_op_replace_passes = PassManager(
482+
base_post_op_replace_passes: List[
483+
Callable[[torch.nn.Module], PassResult]
484+
] = PassManager(
483485
passes=[
484486
dead_code_elimination_pass,
485487
DebugHandleGeneratorPass(),

exir/program/_program.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
from executorch.exir.error import ExportError
2121
from executorch.exir.pass_manager import PassType
2222
from executorch.exir.passes import (
23+
base_post_op_replace_passes,
24+
base_pre_op_replace_passes,
2325
EdgeToBackendOpsPass,
26+
MemoryFormatOpsPass,
2427
OpReplacePass,
25-
post_op_replace_passes,
26-
pre_op_replace_passes,
2728
)
2829
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
2930
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
@@ -459,6 +460,23 @@ def dump_exported_program(self) -> ExportedProgram:
459460
return self.exported_program
460461

461462

463+
def _get_aten_to_edge_passes(config: EdgeCompileConfig):
464+
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
465+
# use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
466+
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
467+
# It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.
468+
469+
pre_op_replace_passes = base_pre_op_replace_passes + (
470+
[] if config._skip_type_promotion else [RemoveMixedTypeOperators()]
471+
)
472+
473+
post_op_replace_passes = (
474+
[] if config._skip_dim_order else [MemoryFormatOpsPass()]
475+
) + base_post_op_replace_passes
476+
477+
return pre_op_replace_passes, post_op_replace_passes
478+
479+
462480
def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
463481
if config._check_ir_validity:
464482
try:
@@ -486,15 +504,9 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
486504
),
487505
False,
488506
)
489-
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
490-
# use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
491-
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
492-
# It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.
507+
pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
493508

494-
passes = pre_op_replace_passes + (
495-
[] if config._skip_type_promotion else [RemoveMixedTypeOperators()]
496-
)
497-
new_ep = copy.deepcopy(ep).transform(*passes)
509+
new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes)
498510
if dialect == "ATEN":
499511
new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program)
500512

@@ -824,17 +836,13 @@ def to_edge(
824836
logging.info(f"Input program {name} is not in ATen dialect.")
825837
raise e
826838

827-
# TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable
828-
# use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play
829-
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
830-
# It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.
839+
pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
840+
831841
passes = []
832842
passes.append(
833843
ReplaceViewOpsWithViewCopyOpsPass()
834844
) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
835845
passes.extend(pre_op_replace_passes)
836-
if not config._skip_type_promotion:
837-
passes.append(RemoveMixedTypeOperators())
838846
if config._use_edge_ops:
839847
passes.append(OpReplacePass())
840848

exir/tests/test_memory_format_ops_pass.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,23 +75,23 @@ class TestSet:
7575
edge_op_str
7676
).run(before.graph_module.code)
7777

78-
ep = to_edge(
79-
before, compile_config=EdgeCompileConfig(_check_ir_validity=False)
80-
) # Only replacing edge_ops
81-
82-
# Run the pass
83-
# TODO move this in to_edge passes, make to_dim_copy pass verifier
84-
after = ep.transform([MemoryFormatOpsPass()], check_ir_validity=False)
78+
# TODO(gasoonjia): make to_dim_copy pass verifier
79+
epm = to_edge(
80+
before,
81+
compile_config=EdgeCompileConfig(
82+
_check_ir_validity=False, _skip_dim_order=False
83+
),
84+
)
8585

8686
# check op strings
8787
FileCheck().check_not(aten_op_str).check_count(
8888
edge_op_str, 1, exactly=True
89-
).run(after.exported_program().graph_module.code)
89+
).run(epm.exported_program().graph_module.code)
9090

9191
# check EdgeOp and the new BackendOp should behave the same
9292
expected = before(*test_set.sample_input)
93-
actual = after.exported_program()(*test_set.sample_input)
93+
actual = epm.exported_program()(*test_set.sample_input)
9494
self.assertTrue(torch.allclose(actual, expected))
9595

9696
# TODO - more
97-
after.to_executorch()
97+
epm.to_executorch()

0 commit comments

Comments
 (0)