@@ -1096,6 +1096,7 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
10961096 can_skip_using_EDGE_DO_NOT_DECOMP = False
10971097 return can_skip_using_EDGE_DO_NOT_DECOMP
10981098
1099+
10991100def _gen_edge_manager_for_partitioners (
11001101 partitioner : Dict [str , List [Partitioner ]],
11011102 aten_programs : Dict [str , ExportedProgram ],
@@ -1118,22 +1119,43 @@ def _gen_edge_manager_for_partitioners(
11181119 ops_set_to_not_decompose_by_program = {}
11191120 edge_programs : Dict [str , ExportedProgram ] = {}
11201121 for name , program in aten_programs .items ():
1122+ # Functionalize program without doing any decompositions
1123+ program = program .run_decompositions ({})
1124+ ReplaceViewOpsWithViewCopyOpsPass ()(program .graph_module )
1125+
1126+ print (program )
1127+
11211128 if partitioner is not None :
11221129 # preserve all ops listed by all partitioners first
11231130 all_ops_no_decomp = set ()
11241131 for curr_partitioner in partitioner .get (name , []):
11251132 curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1133+ < << << << HEAD
11261134 curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
11271135 curr_ops_no_decomp
11281136 )
1137+ == == == =
1138+ >> >> >> > ec44f8478 (updates )
11291139 all_ops_no_decomp |= set (curr_ops_no_decomp )
1130-
1140+
1141+ # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1142+ # Otherwise there will be issues
1143+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1144+ all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (list (all_ops_no_decomp ))
1145+ all_ops_no_decomp = set (all_ops_no_decomp )
1146+
1147+ # Run default decompositions, except for those in all_ops_no_decomp
11311148 table = _default_decomposition_table ()
1132-
11331149 for op in all_ops_no_decomp :
1150+ < << << << HEAD
11341151 table .pop (op , None )
11351152
1153+ == == == =
1154+ if table .pop (op , None ) is not None :
1155+ all_ops_no_decomp_needing_preservation .append (op )
1156+ > >> >> >> ec44f8478 (updates )
11361157 program = program .run_decompositions (table )
1158+
11371159 # Among all the preserved aten ops, use the check_op_fn to do an additional
11381160 # check on which ops need to be preserved and which ops need to be decomposed
11391161 # Those which are truly preserved will be replaced with transformed ops
0 commit comments