@@ -1094,6 +1094,27 @@ def keep(op):
10941094 return list (filter (keep , preserve_ops ))
10951095
10961096
1097+ def _can_skip_using_EDGE_DO_NOT_DECOMP (
1098+ partitioner : Dict [str , List [Partitioner ]], aten_programs : Dict [str , ExportedProgram ]
1099+ ) -> bool :
1100+ # THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
1101+ # has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
1102+ # fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
1103+ # and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
1104+ # EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
1105+ # As a temp fix, we give a more reliable path for backends that do not specify check_op_support
1106+ can_skip_using_EDGE_DO_NOT_DECOMP = True
1107+ for name , program in aten_programs .items ():
1108+ if partitioner is not None :
1109+ for curr_partitioner in partitioner .get (name , []):
1110+ curr_ops_no_decomp , check_op_support = (
1111+ curr_partitioner .ops_to_not_decompose (program )
1112+ )
1113+ if check_op_support is not None :
1114+ can_skip_using_EDGE_DO_NOT_DECOMP = False
1115+ return can_skip_using_EDGE_DO_NOT_DECOMP
1116+
1117+
10971118def _gen_edge_manager_for_partitioners (
10981119 partitioner : Dict [str , List [Partitioner ]],
10991120 aten_programs : Dict [str , ExportedProgram ],
@@ -1113,37 +1134,54 @@ def _gen_edge_manager_for_partitioners(
11131134 on nodes with preserved aten targets. They are then replaces with transformed ops to
11141135 keep them through the second pass of decompositions
11151136 """
1137+
1138+ can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1139+ partitioner , aten_programs
1140+ )
1141+
11161142 ops_set_to_not_decompose_by_program = {}
11171143 edge_programs : Dict [str , ExportedProgram ] = {}
11181144 for name , program in aten_programs .items ():
11191145 if partitioner is not None :
11201146 # preserve all ops listed by all partitioners first
11211147 all_ops_no_decomp = set ()
1148+
1149+ # This holds the subset of all_ops_no_decomp that actually need preservation, i.e.,
1150+ # the ones where the decomposition table has an entry for the op
1151+ all_ops_no_decomp_needing_preservation = []
11221152 for curr_partitioner in partitioner .get (name , []):
11231153 curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1124- curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1125- curr_ops_no_decomp
1126- )
1154+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1155+ curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1156+ curr_ops_no_decomp
1157+ )
11271158 all_ops_no_decomp |= set (curr_ops_no_decomp )
11281159
11291160 table = _default_decomposition_table ()
11301161
11311162 for op in all_ops_no_decomp :
1132- table .pop (op , None )
1163+ if table .pop (op , None ) is not None :
1164+ all_ops_no_decomp_needing_preservation .append (op )
11331165
11341166 program = program .run_decompositions (table )
11351167 # Among all the preserved aten ops, use the check_op_fn to do an additional
11361168 # check on which ops need to be preserved and which ops need to be decomposed
11371169 # Those which are truly preserved will be replaced with transformed ops
1138- ops_set_to_not_decompose_by_program [name ] = (
1139- _replace_aten_ops_with_transformed_ops (name , program , partitioner ) or []
1140- )
1141- program = program .run_decompositions (_default_decomposition_table ())
1170+ if can_skip_using_EDGE_DO_NOT_DECOMP :
1171+ ops_set_to_not_decompose_by_program [name ] = (
1172+ all_ops_no_decomp_needing_preservation
1173+ )
1174+ else :
1175+ ops_set_to_not_decompose_by_program [name ] = (
1176+ _replace_aten_ops_with_transformed_ops (name , program , partitioner )
1177+ or []
1178+ )
11421179
1143- _restore_transformed_ops_to_aten_ops (program )
1180+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1181+ program = program .run_decompositions (_default_decomposition_table ())
1182+ _restore_transformed_ops_to_aten_ops (program )
11441183
11451184 edge_programs [name ] = program
1146-
11471185 edge_programs [name ] = _generate_edge_program (
11481186 name ,
11491187 config ,
0 commit comments