From dbf5b20a001e122b42c7298364ec773e7d29d0b1 Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Mon, 21 Jul 2025 15:54:06 -0700 Subject: [PATCH] Move verification to after to_backend for to_edge_transform_and_lower (#12630) Summary: Some operators require preservation because they are intended to be consumed by a backend. These operators can contain view and mutation, as they won't be part of the graph after to_backend. If there are still view and mutation ops after to_backend, verification should throw an error. This diff: 1. Removes verification check from _generated_edge_program, which is called by to_edge and to_edge_transform_and_lower on the aten dialect. 2. to_edge: run verification for aten dialect (before to_edge) and edge dialect (after to_edge). 3. to_edge_transform_and_lower: only run the edge verification. Reviewed By: metascroy Differential Revision: D78535519 --- exir/program/_program.py | 65 ++++++++++++++++++++--------------- exir/verification/verifier.py | 2 +- 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/exir/program/_program.py b/exir/program/_program.py index cc3bfac4e36..19c06da23bd 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -791,7 +791,6 @@ def edge_to_executorch_passes( def _generate_edge_program( - name: str, config: EdgeCompileConfig, program: ExportedProgram, core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None, @@ -799,7 +798,6 @@ def _generate_edge_program( ) -> ExportedProgram: """ Args: - name: The name of the program. config: The configuration for the edge program. program: The exported program to be converted to an edge program. core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten. @@ -807,39 +805,24 @@ def _generate_edge_program( Returns: An ExportedProgram in edge dialect. """ - # Remove invalid assert ops, such as _assert_tensor_metadata - gm = program.graph_module - gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm) - assert gm_res is not None - gm = gm_res.graph_module - # Remove unused parameters program = remove_unused_parameters_pass(program) - if config._check_ir_validity: - try: - EXIRATenDialectVerifier( - edge_compile_config=config, - class_only=False, - core_aten_ops_exception_list=core_aten_ops_exception_list, - preserve_ops=preserve_ops, - )(gm) - except ExportError as e: - logging.info(f"Input program {name} is not in ATen dialect.") - raise e - pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config) - passes = [] - passes.append( - ReplaceViewOpsWithViewCopyOpsPass() - ) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture + passes = [ + # Remove invalid assert ops, such as _assert_tensor_metadata + RemoveNonCoreAtenOpGraphAssertsPass(), + # TODO move inside aten_to_edge passes after all users are migrated off v1 capture + ReplaceViewOpsWithViewCopyOpsPass(), + ] passes.extend(pre_op_replace_passes) if config._use_edge_ops: passes.append(OpReplacePass()) if not config._skip_dim_order: passes.append(MemoryFormatOpsPass()) + gm = program.graph_module for p in passes: gm_res = p(gm) assert gm_res is not None @@ -1144,7 +1127,6 @@ def _gen_edge_manager_for_partitioners( edge_programs[name] = program edge_programs[name] = _generate_edge_program( - name, config, program, preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])), @@ -1288,11 +1270,12 @@ def to_edge_transform_and_lower( generate_error=True, ) + preserve_ops = config.preserve_ops + list(ops_set_to_not_decompose) if config._check_ir_validity: EXIREdgeDialectVerifier( edge_compile_config=config, class_only=True, - preserve_ops=list(ops_set_to_not_decompose), + preserve_ops=preserve_ops, )()(program.graph_module) return edge_manager @@ -1336,9 +1319,37 @@ def to_edge( for op in compile_config.preserve_ops: table.pop(op, None) program = program.run_decompositions(table) + + if config._check_ir_validity: + # Remove invalid assert ops, such as _assert_tensor_metadata. + # This pass is run in _generate_edge_program; it is required here to + # ensure the graph is in ATen dialect before verification. + gm = program.graph_module + gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm) + assert gm_res is not None + gm = gm_res.graph_module + try: + EXIRATenDialectVerifier( + edge_compile_config=config, + class_only=False, + )(gm) + except ExportError as e: + logging.info(f"Input program {name} is not in ATen dialect.") + raise e + edge_programs[name] = _generate_edge_program( - name, config, program, preserve_ops=preserve_ops + config, program, preserve_ops=preserve_ops ) + if config._check_ir_validity: + try: + EXIREdgeDialectVerifier( + edge_compile_config=config, + class_only=True, + preserve_ops=preserve_ops, + )()(edge_programs[name].graph_module) + except ExportError as e: + logging.info(f"Input program {name} is not in Edge dialect.") + raise e return EdgeProgramManager(edge_programs, constant_methods, config) diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index 6b79b924cd2..2c4a294d3e6 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -145,7 +145,7 @@ def check_valid_op(self, op): # which may affect memory planning. if op.is_view: raise RuntimeError( - f"Cannot preserve operator {op} because it is a view or mutation." + f"Cannot preserve operator {op} because it is a view." ) if op._schema.is_mutable: logging.warning(