@@ -1113,6 +1113,7 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
11131113 can_skip_using_EDGE_DO_NOT_DECOMP = False
11141114 return can_skip_using_EDGE_DO_NOT_DECOMP
11151115
1116+
11161117def _gen_edge_manager_for_partitioners (
11171118 partitioner : Dict [str , List [Partitioner ]],
11181119 aten_programs : Dict [str , ExportedProgram ],
@@ -1135,22 +1136,43 @@ def _gen_edge_manager_for_partitioners(
11351136 ops_set_to_not_decompose_by_program = {}
11361137 edge_programs : Dict [str , ExportedProgram ] = {}
11371138 for name , program in aten_programs .items ():
1139+ # Functionalize program without doing any decompositions
1140+ program = program .run_decompositions ({})
1141+ ReplaceViewOpsWithViewCopyOpsPass ()(program .graph_module )
1142+
1143+ print (program )
1144+
11381145 if partitioner is not None :
11391146 # preserve all ops listed by all partitioners first
11401147 all_ops_no_decomp = set ()
11411148 for curr_partitioner in partitioner .get (name , []):
11421149 curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1150+ < << << << HEAD
11431151 curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
11441152 curr_ops_no_decomp
11451153 )
1154+ == == == =
1155+ >> >> >> > ec44f8478 (updates )
11461156 all_ops_no_decomp |= set (curr_ops_no_decomp )
1147-
1157+
1158+ # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1159+ # Otherwise there will be issues
1160+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1161+ all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (list (all_ops_no_decomp ))
1162+ all_ops_no_decomp = set (all_ops_no_decomp )
1163+
1164+ # Run default decompositions, except for those in all_ops_no_decomp
11481165 table = _default_decomposition_table ()
1149-
11501166 for op in all_ops_no_decomp :
1167+ < << << << HEAD
11511168 table .pop (op , None )
11521169
1170+ == == == =
1171+ if table .pop (op , None ) is not None :
1172+ all_ops_no_decomp_needing_preservation .append (op )
1173+ > >> >> >> ec44f8478 (updates )
11531174 program = program .run_decompositions (table )
1175+
11541176 # Among all the preserved aten ops, use the check_op_fn to do an additional
11551177 # check on which ops need to be preserved and which ops need to be decomposed
11561178 # Those which are truly preserved will be replaced with transformed ops
0 commit comments