@@ -791,55 +791,38 @@ def edge_to_executorch_passes(
791791
792792
793793def _generate_edge_program (
794- name : str ,
795794 config : EdgeCompileConfig ,
796795 program : ExportedProgram ,
797796 core_aten_ops_exception_list : Optional [List [torch ._ops .OpOverload ]] = None ,
798797 preserve_ops : Optional [List [torch ._ops .OpOverload ]] = None ,
799798) -> ExportedProgram :
800799 """
801800 Args:
802- name: The name of the program.
803801 config: The configuration for the edge program.
804802 program: The exported program to be converted to an edge program.
805803 core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
806804 preserve_ops: A list of aten ops that should not be decomposed.
807805 Returns:
808806 An ExportedProgram in edge dialect.
809807 """
810- # Remove invalid assert ops, such as _assert_tensor_metadata
811- gm = program .graph_module
812- gm_res = RemoveNonCoreAtenOpGraphAssertsPass ()(gm )
813- assert gm_res is not None
814- gm = gm_res .graph_module
815-
816808 # Remove unused parameters
817809 program = remove_unused_parameters_pass (program )
818810
819- if config ._check_ir_validity :
820- try :
821- EXIRATenDialectVerifier (
822- edge_compile_config = config ,
823- class_only = False ,
824- core_aten_ops_exception_list = core_aten_ops_exception_list ,
825- preserve_ops = preserve_ops ,
826- )(gm )
827- except ExportError as e :
828- logging .info (f"Input program { name } is not in ATen dialect." )
829- raise e
830-
831811 pre_op_replace_passes , post_op_replace_passes = _get_aten_to_edge_passes (config )
832812
833- passes = []
834- passes .append (
835- ReplaceViewOpsWithViewCopyOpsPass ()
836- ) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
813+ passes = [
814+ # Remove invalid assert ops, such as _assert_tensor_metadata
815+ RemoveNonCoreAtenOpGraphAssertsPass (),
816+ # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
817+ ReplaceViewOpsWithViewCopyOpsPass (),
818+ ]
837819 passes .extend (pre_op_replace_passes )
838820 if config ._use_edge_ops :
839821 passes .append (OpReplacePass ())
840822 if not config ._skip_dim_order :
841823 passes .append (MemoryFormatOpsPass ())
842824
825+ gm = program .graph_module
843826 for p in passes :
844827 gm_res = p (gm )
845828 assert gm_res is not None
@@ -1144,7 +1127,6 @@ def _gen_edge_manager_for_partitioners(
11441127 edge_programs [name ] = program
11451128
11461129 edge_programs [name ] = _generate_edge_program (
1147- name ,
11481130 config ,
11491131 program ,
11501132 preserve_ops = list (ops_set_to_not_decompose_by_program .get (name , [])),
@@ -1288,11 +1270,12 @@ def to_edge_transform_and_lower(
12881270 generate_error = True ,
12891271 )
12901272
1273+ preserve_ops = config .preserve_ops + list (ops_set_to_not_decompose )
12911274 if config ._check_ir_validity :
12921275 EXIREdgeDialectVerifier (
12931276 edge_compile_config = config ,
12941277 class_only = True ,
1295- preserve_ops = list ( ops_set_to_not_decompose ) ,
1278+ preserve_ops = preserve_ops ,
12961279 )()(program .graph_module )
12971280
12981281 return edge_manager
@@ -1336,9 +1319,36 @@ def to_edge(
13361319 for op in compile_config .preserve_ops :
13371320 table .pop (op , None )
13381321 program = program .run_decompositions (table )
1322+
1323+ if config ._check_ir_validity :
1324+ # Remove invalid assert ops, such as _assert_tensor_metadata, before verification.
1325+ # This pass is run in _generate_edge_program.
1326+ gm = program .graph_module
1327+ gm_res = RemoveNonCoreAtenOpGraphAssertsPass ()(gm )
1328+ assert gm_res is not None
1329+ gm = gm_res .graph_module
1330+ try :
1331+ EXIRATenDialectVerifier (
1332+ edge_compile_config = config ,
1333+ class_only = False ,
1334+ )(gm )
1335+ except ExportError as e :
1336+ logging .info (f"Input program { name } is not in ATen dialect." )
1337+ raise e
1338+
13391339 edge_programs [name ] = _generate_edge_program (
1340- name , config , program , preserve_ops = preserve_ops
1340+ config , program , preserve_ops = preserve_ops
13411341 )
1342+ if config ._check_ir_validity :
1343+ try :
1344+ EXIREdgeDialectVerifier (
1345+ edge_compile_config = config ,
1346+ class_only = True ,
1347+ preserve_ops = preserve_ops ,
1348+ )()(edge_programs [name ].graph_module )
1349+ except ExportError as e :
1350+ logging .info (f"Input program { name } is not in Edge dialect." )
1351+ raise e
13421352
13431353 return EdgeProgramManager (edge_programs , constant_methods , config )
13441354
0 commit comments