@@ -1082,6 +1082,27 @@ def keep(op):
10821082 return list (filter (keep , ops_to_not_decompose ))
10831083
10841084
1085+ def _can_skip_using_EDGE_DO_NOT_DECOMP (
1086+ partitioner : Dict [str , List [Partitioner ]], aten_programs : Dict [str , ExportedProgram ]
1087+ ) -> bool :
1088+ # THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
1089+ # has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
1090+ # fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
1091+ # and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
1092+ # EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
1093+ # As a temp fix, we give a more reliable path for backends that do not specify check_op_support
1094+ can_skip_using_EDGE_DO_NOT_DECOMP = True
1095+ for name , program in aten_programs .items ():
1096+ if partitioner is not None :
1097+ for curr_partitioner in partitioner .get (name , []):
1098+ curr_ops_no_decomp , check_op_support = (
1099+ curr_partitioner .ops_to_not_decompose (program )
1100+ )
1101+ if check_op_support is not None :
1102+ can_skip_using_EDGE_DO_NOT_DECOMP = False
1103+ return can_skip_using_EDGE_DO_NOT_DECOMP
1104+
1105+
10851106def _gen_edge_manager_for_partitioners (
10861107 partitioner : Dict [str , List [Partitioner ]],
10871108 aten_programs : Dict [str , ExportedProgram ],
@@ -1101,37 +1122,54 @@ def _gen_edge_manager_for_partitioners(
11011122 on nodes with preserved aten targets. They are then replaces with transformed ops to
11021123 keep them through the second pass of decompositions
11031124 """
1125+
1126+ can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1127+ partitioner , aten_programs
1128+ )
1129+
11041130 ops_set_to_not_decompose_by_program = {}
11051131 edge_programs : Dict [str , ExportedProgram ] = {}
11061132 for name , program in aten_programs .items ():
11071133 if partitioner is not None :
11081134 # preserve all ops listed by all partitioners first
11091135 all_ops_no_decomp = set ()
1136+
1137+ # This holds the subset of all_ops_no_decomp that actually need preservation, i.e.,
1138+ # the ones where the decomposition table has an entry for the op
1139+ all_ops_no_decomp_needing_preservation = []
11101140 for curr_partitioner in partitioner .get (name , []):
11111141 curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1112- curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1113- curr_ops_no_decomp
1114- )
1142+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1143+ curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1144+ curr_ops_no_decomp
1145+ )
11151146 all_ops_no_decomp |= set (curr_ops_no_decomp )
11161147
11171148 table = _default_decomposition_table ()
11181149
11191150 for op in all_ops_no_decomp :
1120- table .pop (op , None )
1151+ if table .pop (op , None ) is not None :
1152+ all_ops_no_decomp_needing_preservation .append (op )
11211153
11221154 program = program .run_decompositions (table )
11231155 # Among all the preserved aten ops, use the check_op_fn to do an additional
11241156 # check on which ops need to be preserved and which ops need to be decomposed
11251157 # Those which are truly preserved will be replaced with transformed ops
1126- ops_set_to_not_decompose_by_program [name ] = (
1127- _replace_aten_ops_with_transformed_ops (name , program , partitioner ) or []
1128- )
1129- program = program .run_decompositions (_default_decomposition_table ())
1158+ if can_skip_using_EDGE_DO_NOT_DECOMP :
1159+ ops_set_to_not_decompose_by_program [name ] = (
1160+ all_ops_no_decomp_needing_preservation
1161+ )
1162+ else :
1163+ ops_set_to_not_decompose_by_program [name ] = (
1164+ _replace_aten_ops_with_transformed_ops (name , program , partitioner )
1165+ or []
1166+ )
11301167
1131- _restore_transformed_ops_to_aten_ops (program )
1168+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1169+ program = program .run_decompositions (_default_decomposition_table ())
1170+ _restore_transformed_ops_to_aten_ops (program )
11321171
11331172 edge_programs [name ] = program
1134-
11351173 edge_programs [name ] = _generate_edge_program (
11361174 name ,
11371175 config ,
0 commit comments