@@ -1076,6 +1076,28 @@ def keep(op):
10761076 return list (filter (keep , preserve_ops ))
10771077
10781078
1079+ def _can_skip_using_EDGE_DO_NOT_DECOMP (
1080+ partitioner : Dict [str , List [Partitioner ]], aten_programs : Dict [str , ExportedProgram ]
1081+ ) -> bool :
1082+ # THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
1083+ # has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
1084+ # fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
1085+ # and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
1086+ # EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
1087+ # As a temp fix, we give a more reliable path for backends that do not specify check_op_support
1088+ can_skip_using_EDGE_DO_NOT_DECOMP = True
1089+ for name , program in aten_programs .items ():
1090+ if partitioner is not None :
1091+ for curr_partitioner in partitioner .get (name , []):
1092+ (
1093+ curr_ops_no_decomp ,
1094+ check_op_support ,
1095+ ) = curr_partitioner .ops_to_not_decompose (program )
1096+ if check_op_support is not None :
1097+ can_skip_using_EDGE_DO_NOT_DECOMP = False
1098+ return can_skip_using_EDGE_DO_NOT_DECOMP
1099+
1100+
10791101def _gen_edge_manager_for_partitioners (
10801102 partitioner : Dict [str , List [Partitioner ]],
10811103 aten_programs : Dict [str , ExportedProgram ],
@@ -1095,37 +1117,56 @@ def _gen_edge_manager_for_partitioners(
10951117 on nodes with preserved aten targets. They are then replaces with transformed ops to
10961118 keep them through the second pass of decompositions
10971119 """
1120+ can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1121+ partitioner , aten_programs
1122+ )
10981123 ops_set_to_not_decompose_by_program = {}
10991124 edge_programs : Dict [str , ExportedProgram ] = {}
11001125 for name , program in aten_programs .items ():
1126+ # Functionalize program before asking partitioners to preserve ops
1127+ program = program .run_decompositions ({})
1128+
11011129 if partitioner is not None :
11021130 # preserve all ops listed by all partitioners first
11031131 all_ops_no_decomp = set ()
1132+ all_ops_no_decomp_needing_preservation = []
11041133 for curr_partitioner in partitioner .get (name , []):
11051134 curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1106- curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1107- curr_ops_no_decomp
1108- )
11091135 all_ops_no_decomp |= set (curr_ops_no_decomp )
11101136
1111- table = _default_decomposition_table ()
1137+ # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1138+ # Otherwise there will be issues
1139+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1140+ all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1141+ list (all_ops_no_decomp )
1142+ )
1143+ all_ops_no_decomp = set (all_ops_no_decomp )
11121144
1145+ # Run default decompositions, except for those in all_ops_no_decomp
1146+ table = _default_decomposition_table ()
11131147 for op in all_ops_no_decomp :
1114- table .pop (op , None )
1115-
1148+ if table .pop (op , None ) is not None :
1149+ all_ops_no_decomp_needing_preservation . append ( op )
11161150 program = program .run_decompositions (table )
1151+
11171152 # Among all the preserved aten ops, use the check_op_fn to do an additional
11181153 # check on which ops need to be preserved and which ops need to be decomposed
11191154 # Those which are truly preserved will be replaced with transformed ops
1120- ops_set_to_not_decompose_by_program [name ] = (
1121- _replace_aten_ops_with_transformed_ops (name , program , partitioner ) or []
1122- )
1123- program = program .run_decompositions (_default_decomposition_table ())
1155+ if can_skip_using_EDGE_DO_NOT_DECOMP :
1156+ ops_set_to_not_decompose_by_program [name ] = (
1157+ all_ops_no_decomp_needing_preservation
1158+ )
1159+ else :
1160+ ops_set_to_not_decompose_by_program [name ] = (
1161+ _replace_aten_ops_with_transformed_ops (name , program , partitioner )
1162+ or []
1163+ )
11241164
1125- _restore_transformed_ops_to_aten_ops (program )
1165+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1166+ program = program .run_decompositions (_default_decomposition_table ())
1167+ _restore_transformed_ops_to_aten_ops (program )
11261168
11271169 edge_programs [name ] = program
1128-
11291170 edge_programs [name ] = _generate_edge_program (
11301171 config ,
11311172 program ,
@@ -1169,7 +1210,7 @@ def collect_named_data_store_outputs(
11691210
11701211
11711212@et_logger ("to_edge_transform_and_lower" )
1172- def to_edge_transform_and_lower (
1213+ def to_edge_transform_and_lower ( # noqa: C901
11731214 programs : Union [ExportedProgram , Dict [str , ExportedProgram ]],
11741215 transform_passes : Optional [
11751216 Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
@@ -1234,6 +1275,9 @@ def to_edge_transform_and_lower(
12341275 elif partitioner is None :
12351276 partitioner = {name : [] for name in aten_programs .keys ()}
12361277
1278+ can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1279+ partitioner , aten_programs
1280+ )
12371281 edge_manager = _gen_edge_manager_for_partitioners (
12381282 partitioner , aten_programs , config , constant_methods
12391283 )
@@ -1259,7 +1303,8 @@ def to_edge_transform_and_lower(
12591303 curr_op_set , check_op_support = curr_partitioner .ops_to_not_decompose (
12601304 program
12611305 )
1262- curr_op_set = _remove_invalid_ops_for_not_decompose (curr_op_set )
1306+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1307+ curr_op_set = _remove_invalid_ops_for_not_decompose (curr_op_set )
12631308 ops_set_to_not_decompose = ops_set_to_not_decompose .union (curr_op_set )
12641309 _sanity_check_graph_for_non_decomp_ops (
12651310 name ,
0 commit comments