@@ -1106,14 +1106,31 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
11061106 for name , program in aten_programs .items ():
11071107 if partitioner is not None :
11081108 for curr_partitioner in partitioner .get (name , []):
1109- curr_ops_no_decomp , check_op_support = (
1110- curr_partitioner .ops_to_not_decompose (program )
1111- )
1109+ (
1110+ curr_ops_no_decomp ,
1111+ check_op_support ,
1112+ ) = curr_partitioner .ops_to_not_decompose (program )
11121113 if check_op_support is not None :
11131114 can_skip_using_EDGE_DO_NOT_DECOMP = False
11141115 return can_skip_using_EDGE_DO_NOT_DECOMP
11151116
11161117
1118+ def _replace_view_with_view_copy (program : ExportedProgram ) -> ExportedProgram :
1119+ program = program .run_decompositions ({})
1120+ new_gm = ReplaceViewOpsWithViewCopyOpsPass ()(program .graph_module ).graph_module
1121+ program = ExportedProgram (
1122+ root = new_gm ,
1123+ graph = new_gm .graph ,
1124+ graph_signature = _get_updated_graph_signature (program .graph_signature , new_gm ),
1125+ state_dict = program .state_dict ,
1126+ range_constraints = program .range_constraints ,
1127+ module_call_graph = program .module_call_graph ,
1128+ example_inputs = program .example_inputs ,
1129+ constants = program .constants ,
1130+ )
1131+ return program
1132+
1133+
11171134def _gen_edge_manager_for_partitioners (
11181135 partitioner : Dict [str , List [Partitioner ]],
11191136 aten_programs : Dict [str , ExportedProgram ],
@@ -1133,58 +1150,55 @@ def _gen_edge_manager_for_partitioners(
11331150 on nodes with preserved aten targets. They are then replaces with transformed ops to
11341151 keep them through the second pass of decompositions
11351152 """
1153+ can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1154+ partitioner , aten_programs
1155+ )
11361156 ops_set_to_not_decompose_by_program = {}
11371157 edge_programs : Dict [str , ExportedProgram ] = {}
11381158 for name , program in aten_programs .items ():
1139- # Functionalize program without doing any decompositions
1140- program = program .run_decompositions ({})
1141- ReplaceViewOpsWithViewCopyOpsPass ()(program .graph_module )
1142-
1143- print (program )
1144-
11451159 if partitioner is not None :
11461160 # preserve all ops listed by all partitioners first
11471161 all_ops_no_decomp = set ()
1162+ all_ops_no_decomp_needing_preservation = []
11481163 for curr_partitioner in partitioner .get (name , []):
11491164 curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1150- < << << << HEAD
1151- curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1152- curr_ops_no_decomp
1153- )
1154- == == == =
1155- >> >> >> > ec44f8478 (updates )
11561165 all_ops_no_decomp |= set (curr_ops_no_decomp )
1157-
1166+
11581167 # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1159- # Otherwise there will be issues
1168+ # Otherwise there will be issues
11601169 if not can_skip_using_EDGE_DO_NOT_DECOMP :
1161- all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (list (all_ops_no_decomp ))
1170+ all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1171+ list (all_ops_no_decomp )
1172+ )
11621173 all_ops_no_decomp = set (all_ops_no_decomp )
11631174
11641175 # Run default decompositions, except for those in all_ops_no_decomp
11651176 table = _default_decomposition_table ()
11661177 for op in all_ops_no_decomp :
1167- < << << << HEAD
1168- table .pop (op , None )
1169-
1170- == == == =
11711178 if table .pop (op , None ) is not None :
11721179 all_ops_no_decomp_needing_preservation .append (op )
1173- > >> >> >> ec44f8478 (updates )
11741180 program = program .run_decompositions (table )
11751181
11761182 # Among all the preserved aten ops, use the check_op_fn to do an additional
11771183 # check on which ops need to be preserved and which ops need to be decomposed
11781184 # Those which are truly preserved will be replaced with transformed ops
1179- ops_set_to_not_decompose_by_program [name ] = (
1180- _replace_aten_ops_with_transformed_ops (name , program , partitioner ) or []
1181- )
1182- program = program .run_decompositions (_default_decomposition_table ())
1185+ if can_skip_using_EDGE_DO_NOT_DECOMP :
1186+ ops_set_to_not_decompose_by_program [
1187+ name
1188+ ] = all_ops_no_decomp_needing_preservation
1189+ else :
1190+ ops_set_to_not_decompose_by_program [name ] = (
1191+ _replace_aten_ops_with_transformed_ops (name , program , partitioner )
1192+ or []
1193+ )
11831194
1184- _restore_transformed_ops_to_aten_ops (program )
1195+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1196+ program = program .run_decompositions (_default_decomposition_table ())
1197+ _restore_transformed_ops_to_aten_ops (program )
11851198
1199+ # Edge will complain if there are view ops requested for preservation, so we replace them with view_copy
1200+ program = _replace_view_with_view_copy (program )
11861201 edge_programs [name ] = program
1187-
11881202 edge_programs [name ] = _generate_edge_program (
11891203 name ,
11901204 config ,
@@ -1229,7 +1243,7 @@ def collect_named_data_store_outputs(
12291243
12301244
12311245@et_logger ("to_edge_transform_and_lower" )
1232- def to_edge_transform_and_lower (
1246+ def to_edge_transform_and_lower ( # noqa: C901
12331247 programs : Union [ExportedProgram , Dict [str , ExportedProgram ]],
12341248 transform_passes : Optional [
12351249 Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
@@ -1294,6 +1308,9 @@ def to_edge_transform_and_lower(
12941308 elif partitioner is None :
12951309 partitioner = {name : [] for name in aten_programs .keys ()}
12961310
1311+ can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1312+ partitioner , aten_programs
1313+ )
12971314 edge_manager = _gen_edge_manager_for_partitioners (
12981315 partitioner , aten_programs , config , constant_methods
12991316 )
@@ -1319,7 +1336,8 @@ def to_edge_transform_and_lower(
13191336 curr_op_set , check_op_support = curr_partitioner .ops_to_not_decompose (
13201337 program
13211338 )
1322- curr_op_set = _remove_invalid_ops_for_not_decompose (curr_op_set )
1339+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1340+ curr_op_set = _remove_invalid_ops_for_not_decompose (curr_op_set )
13231341 ops_set_to_not_decompose = ops_set_to_not_decompose .union (curr_op_set )
13241342 _sanity_check_graph_for_non_decomp_ops (
13251343 name ,
0 commit comments