@@ -1089,14 +1089,31 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
10891089 for name , program in aten_programs .items ():
10901090 if partitioner is not None :
10911091 for curr_partitioner in partitioner .get (name , []):
1092- curr_ops_no_decomp , check_op_support = (
1093- curr_partitioner .ops_to_not_decompose (program )
1094- )
1092+ (
1093+ curr_ops_no_decomp ,
1094+ check_op_support ,
1095+ ) = curr_partitioner .ops_to_not_decompose (program )
10951096 if check_op_support is not None :
10961097 can_skip_using_EDGE_DO_NOT_DECOMP = False
10971098 return can_skip_using_EDGE_DO_NOT_DECOMP
10981099
10991100
1101+ def _replace_view_with_view_copy (program : ExportedProgram ) -> ExportedProgram :
1102+ program = program .run_decompositions ({})
1103+ new_gm = ReplaceViewOpsWithViewCopyOpsPass ()(program .graph_module ).graph_module
1104+ program = ExportedProgram (
1105+ root = new_gm ,
1106+ graph = new_gm .graph ,
1107+ graph_signature = _get_updated_graph_signature (program .graph_signature , new_gm ),
1108+ state_dict = program .state_dict ,
1109+ range_constraints = program .range_constraints ,
1110+ module_call_graph = program .module_call_graph ,
1111+ example_inputs = program .example_inputs ,
1112+ constants = program .constants ,
1113+ )
1114+ return program
1115+
1116+
11001117def _gen_edge_manager_for_partitioners (
11011118 partitioner : Dict [str , List [Partitioner ]],
11021119 aten_programs : Dict [str , ExportedProgram ],
@@ -1116,58 +1133,55 @@ def _gen_edge_manager_for_partitioners(
11161133 on nodes with preserved aten targets. They are then replaces with transformed ops to
11171134 keep them through the second pass of decompositions
11181135 """
1136+ can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1137+ partitioner , aten_programs
1138+ )
11191139 ops_set_to_not_decompose_by_program = {}
11201140 edge_programs : Dict [str , ExportedProgram ] = {}
11211141 for name , program in aten_programs .items ():
1122- # Functionalize program without doing any decompositions
1123- program = program .run_decompositions ({})
1124- ReplaceViewOpsWithViewCopyOpsPass ()(program .graph_module )
1125-
1126- print (program )
1127-
11281142 if partitioner is not None :
11291143 # preserve all ops listed by all partitioners first
11301144 all_ops_no_decomp = set ()
1145+ all_ops_no_decomp_needing_preservation = []
11311146 for curr_partitioner in partitioner .get (name , []):
11321147 curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1133- < << << << HEAD
1134- curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1135- curr_ops_no_decomp
1136- )
1137- == == == =
1138- >> >> >> > ec44f8478 (updates )
11391148 all_ops_no_decomp |= set (curr_ops_no_decomp )
1140-
1149+
11411150 # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1142- # Otherwise there will be issues
1151+ # Otherwise there will be issues
11431152 if not can_skip_using_EDGE_DO_NOT_DECOMP :
1144- all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (list (all_ops_no_decomp ))
1153+ all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1154+ list (all_ops_no_decomp )
1155+ )
11451156 all_ops_no_decomp = set (all_ops_no_decomp )
11461157
11471158 # Run default decompositions, except for those in all_ops_no_decomp
11481159 table = _default_decomposition_table ()
11491160 for op in all_ops_no_decomp :
1150- < << << << HEAD
1151- table .pop (op , None )
1152-
1153- == == == =
11541161 if table .pop (op , None ) is not None :
11551162 all_ops_no_decomp_needing_preservation .append (op )
1156- > >> >> >> ec44f8478 (updates )
11571163 program = program .run_decompositions (table )
11581164
11591165 # Among all the preserved aten ops, use the check_op_fn to do an additional
11601166 # check on which ops need to be preserved and which ops need to be decomposed
11611167 # Those which are truly preserved will be replaced with transformed ops
1162- ops_set_to_not_decompose_by_program [name ] = (
1163- _replace_aten_ops_with_transformed_ops (name , program , partitioner ) or []
1164- )
1165- program = program .run_decompositions (_default_decomposition_table ())
1168+ if can_skip_using_EDGE_DO_NOT_DECOMP :
1169+ ops_set_to_not_decompose_by_program [
1170+ name
1171+ ] = all_ops_no_decomp_needing_preservation
1172+ else :
1173+ ops_set_to_not_decompose_by_program [name ] = (
1174+ _replace_aten_ops_with_transformed_ops (name , program , partitioner )
1175+ or []
1176+ )
11661177
1167- _restore_transformed_ops_to_aten_ops (program )
1178+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1179+ program = program .run_decompositions (_default_decomposition_table ())
1180+ _restore_transformed_ops_to_aten_ops (program )
11681181
1182+ # Edge will complain if there are view ops requested for preservation, so we replace them with view_copy
1183+ program = _replace_view_with_view_copy (program )
11691184 edge_programs [name ] = program
1170-
11711185 edge_programs [name ] = _generate_edge_program (
11721186 config ,
11731187 program ,
@@ -1211,7 +1225,7 @@ def collect_named_data_store_outputs(
12111225
12121226
12131227@et_logger ("to_edge_transform_and_lower" )
1214- def to_edge_transform_and_lower (
1228+ def to_edge_transform_and_lower ( # noqa: C901
12151229 programs : Union [ExportedProgram , Dict [str , ExportedProgram ]],
12161230 transform_passes : Optional [
12171231 Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
@@ -1276,6 +1290,9 @@ def to_edge_transform_and_lower(
12761290 elif partitioner is None :
12771291 partitioner = {name : [] for name in aten_programs .keys ()}
12781292
1293+ can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1294+ partitioner , aten_programs
1295+ )
12791296 edge_manager = _gen_edge_manager_for_partitioners (
12801297 partitioner , aten_programs , config , constant_methods
12811298 )
@@ -1301,7 +1318,8 @@ def to_edge_transform_and_lower(
13011318 curr_op_set , check_op_support = curr_partitioner .ops_to_not_decompose (
13021319 program
13031320 )
1304- curr_op_set = _remove_invalid_ops_for_not_decompose (curr_op_set )
1321+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1322+ curr_op_set = _remove_invalid_ops_for_not_decompose (curr_op_set )
13051323 ops_set_to_not_decompose = ops_set_to_not_decompose .union (curr_op_set )
13061324 _sanity_check_graph_for_non_decomp_ops (
13071325 name ,
0 commit comments