From ebb9eaf7282b38b83a5ca45aa798299de3b16d18 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 13 Dec 2024 10:50:01 -0800 Subject: [PATCH] Add RemoveAssert pass to remove `_assert_tensor_metadata` nodes (#7277) Summary: `_assert_tensor_metadata` nodes is added to the result of exported graphs in D66988295. (More background in T209705957, mostly to relax aten.to constraint). Add a pass to remove this op when calling to_edge. Differential Revision: D67057219 --- exir/passes/remove_graph_asserts_pass.py | 23 +++++++++++++++++++++++ exir/program/_program.py | 15 ++++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/exir/passes/remove_graph_asserts_pass.py b/exir/passes/remove_graph_asserts_pass.py index a46d4cadef4..870621876e8 100644 --- a/exir/passes/remove_graph_asserts_pass.py +++ b/exir/passes/remove_graph_asserts_pass.py @@ -29,6 +29,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: torch.ops.aten._assert_scalar.default, torch.ops.aten.sym_constrain_range_for_size.default, torch.ops.aten.sym_constrain_range.default, + torch.ops.aten._assert_tensor_metadata.default, ) ): module.graph.erase_node(node) @@ -37,3 +38,25 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: module.graph.eliminate_dead_code() return PassResult(graph_module, True) + + +class RemoveNonCoreAtenOpGraphAssertsPass(PassBase): + """ + Remove assert ops from the graph that're not Aten Canonical. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + for node in module.graph.nodes: + if node.op == "call_function" and ( + node.target in (torch.ops.aten._assert_tensor_metadata.default,) + ): + module.graph.erase_node(node) + + module.recompile() + module.graph.eliminate_dead_code() + + return PassResult(graph_module, True) diff --git a/exir/program/_program.py b/exir/program/_program.py index fd1d0aca3dc..2a30d1ed374 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -40,7 +40,10 @@ from executorch.exir.passes.normalize_view_copy_base_pass import ( NormalizeViewCopyBasePass, ) -from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass +from executorch.exir.passes.remove_graph_asserts_pass import ( + RemoveGraphAssertsPass, + RemoveNonCoreAtenOpGraphAssertsPass, +) from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators from executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge from executorch.exir.passes.replace_view_copy_with_view_pass import ( @@ -722,13 +725,20 @@ def _generate_edge_program( program: ExportedProgram, ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None, ) -> ExportedProgram: + + # Remove invalid assert ops, such as _assert_tensor_metadata + gm = program.graph_module + gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm) + assert gm_res is not None + gm = gm_res.graph_module + if config._check_ir_validity: try: EXIRATenDialectVerifier( edge_compile_config=config, class_only=False, exception_list=ops_set_to_not_decompose, - )(program.graph_module) + )(gm) except ExportError as e: logging.info(f"Input program {name} is not in ATen dialect.") raise e @@ -745,7 +755,6 @@ def _generate_edge_program( if not config._skip_dim_order: passes.append(MemoryFormatOpsPass()) - gm = program.graph_module for p in passes: gm_res = p(gm) assert gm_res is not None