|
20 | 20 | from executorch.exir.error import ExportError |
21 | 21 | from executorch.exir.pass_manager import PassType |
22 | 22 | from executorch.exir.passes import ( |
| 23 | + base_post_op_replace_passes, |
| 24 | + base_pre_op_replace_passes, |
23 | 25 | EdgeToBackendOpsPass, |
| 26 | + MemoryFormatOpsPass, |
24 | 27 | OpReplacePass, |
25 | | - post_op_replace_passes, |
26 | | - pre_op_replace_passes, |
27 | 28 | ) |
28 | 29 | from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass |
29 | 30 | from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators |
@@ -459,6 +460,23 @@ def dump_exported_program(self) -> ExportedProgram: |
459 | 460 | return self.exported_program |
460 | 461 |
|
461 | 462 |
|
| 463 | +def _get_aten_to_edge_passes(config: EdgeCompileConfig): |
| 464 | + # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable |
| 465 | + # use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play |
| 466 | + # well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta. |
| 467 | + # It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost. |
| 468 | + |
| 469 | + pre_op_replace_passes = base_pre_op_replace_passes + ( |
| 470 | + [] if config._skip_type_promotion else [RemoveMixedTypeOperators()] |
| 471 | + ) |
| 472 | + |
| 473 | + post_op_replace_passes = ( |
| 474 | + [] if config._skip_dim_order else [MemoryFormatOpsPass()] |
| 475 | + ) + base_post_op_replace_passes |
| 476 | + |
| 477 | + return pre_op_replace_passes, post_op_replace_passes |
| 478 | + |
| 479 | + |
462 | 480 | def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram": |
463 | 481 | if config._check_ir_validity: |
464 | 482 | try: |
@@ -486,15 +504,9 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram": |
486 | 504 | ), |
487 | 505 | False, |
488 | 506 | ) |
489 | | - # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable |
490 | | - # use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play |
491 | | - # well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta. |
492 | | - # It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost. |
| 507 | + pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config) |
493 | 508 |
|
494 | | - passes = pre_op_replace_passes + ( |
495 | | - [] if config._skip_type_promotion else [RemoveMixedTypeOperators()] |
496 | | - ) |
497 | | - new_ep = copy.deepcopy(ep).transform(*passes) |
| 509 | + new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes) |
498 | 510 | if dialect == "ATEN": |
499 | 511 | new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program) |
500 | 512 |
|
@@ -824,17 +836,13 @@ def to_edge( |
824 | 836 | logging.info(f"Input program {name} is not in ATen dialect.") |
825 | 837 | raise e |
826 | 838 |
|
827 | | - # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable |
828 | | - # use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play |
829 | | - # well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta. |
830 | | - # It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost. |
| 839 | + pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config) |
| 840 | + |
831 | 841 | passes = [] |
832 | 842 | passes.append( |
833 | 843 | ReplaceViewOpsWithViewCopyOpsPass() |
834 | 844 | ) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture |
835 | 845 | passes.extend(pre_op_replace_passes) |
836 | | - if not config._skip_type_promotion: |
837 | | - passes.append(RemoveMixedTypeOperators()) |
838 | 846 | if config._use_edge_ops: |
839 | 847 | passes.append(OpReplacePass()) |
840 | 848 |
|
|
0 commit comments