Skip to content

Commit d5a7f33

Browse files
authored
[reland] Fix coreml to edge transform and lower (#12629)
Re-land of: #12564. The previous attempt had a conflict with #12306 that caused a CI failure. ------ The current design of using EDGE_DO_NOT_DECOMP to prevent decomposition has long-standing issues, and often fails lowering when certain ops are requested for preservation. This shows up most notably in the CoreML backend, where most ops are requested for preservation. As a band-aid, we introduced _remove_invalid_ops_for_not_decompose to cover certain kinds of ops. But when an op is encountered that we do not have an exception for, lowering still fails. We also recently found another bug that shows up for SDPA related to contiguous views (https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/) whose root cause we still do not fully understand. EDGE_DO_NOT_DECOMP is actually only used to support the "check_op_support" argument in the partitioner; ops_to_not_decompose only modifies the default decomposition table. In CoreML's case, "check_op_support" is not used, and the issues with EDGE_DO_NOT_DECOMP's design cause lots of lowering issues that are hard to keep up with. This PR enables a new path that bypasses EDGE_DO_NOT_DECOMP when possible (_can_skip_using_EDGE_DO_NOT_DECOMP). Long term, we need to address the buggy design of EDGE_DO_NOT_DECOMP. There are some ideas here: https://fb.workplace.com/groups/pytorch.edge2.team/permalink/1241898747065975/ cc @kimishpatel @YifanShenSZ @cymbalrush
1 parent 6c4f934 commit d5a7f33

File tree

1 file changed

+59
-14
lines changed

1 file changed

+59
-14
lines changed

exir/program/_program.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,28 @@ def keep(op):
10761076
return list(filter(keep, preserve_ops))
10771077

10781078

1079+
def _can_skip_using_EDGE_DO_NOT_DECOMP(
    partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram]
) -> bool:
    """Report whether lowering can bypass the EDGE_DO_NOT_DECOMP mechanism.

    Args:
        partitioner: Mapping from method name to the partitioners registered
            for that method. ``None`` is tolerated and treated as "no
            partitioners".
        aten_programs: Mapping from method name to its exported ATen program.

    Returns:
        ``True`` when no partitioner registered for any program in
        ``aten_programs`` returns a non-``None`` ``check_op_support`` from
        ``ops_to_not_decompose``; ``False`` as soon as one does.
    """
    # The current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
    # has long-standing issues. _remove_invalid_ops_for_not_decompose was a
    # band-aid to fix some of the issues, but more issues are coming up over
    # time, including a new issue with SDPA and contiguous views:
    # https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
    # EDGE_DO_NOT_DECOMP is only needed by partitioners that specify
    # check_op_support. As a temp fix, we give a more reliable path for
    # backends that do not specify check_op_support.
    if partitioner is None:
        # Nothing can request check_op_support without partitioners.
        return True

    for name, program in aten_programs.items():
        for curr_partitioner in partitioner.get(name, []):
            # Only the second tuple element (the optional filter) matters here.
            _, check_op_support = curr_partitioner.ops_to_not_decompose(program)
            if check_op_support is not None:
                # One filter is enough to require the EDGE_DO_NOT_DECOMP path,
                # so the remaining partitioners need not be queried.
                # NOTE(review): assumes ops_to_not_decompose is a side-effect-free
                # query (callers invoke it again later anyway) — confirm.
                return False
    return True
1099+
1100+
10791101
def _gen_edge_manager_for_partitioners(
10801102
partitioner: Dict[str, List[Partitioner]],
10811103
aten_programs: Dict[str, ExportedProgram],
@@ -1095,37 +1117,56 @@ def _gen_edge_manager_for_partitioners(
10951117
on nodes with preserved aten targets. They are then replaced with transformed ops to
10961118
keep them through the second pass of decompositions
10971119
"""
1120+
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1121+
partitioner, aten_programs
1122+
)
10981123
ops_set_to_not_decompose_by_program = {}
10991124
edge_programs: Dict[str, ExportedProgram] = {}
11001125
for name, program in aten_programs.items():
1126+
# Functionalize program before asking partitioners to preserve ops
1127+
program = program.run_decompositions({})
1128+
11011129
if partitioner is not None:
11021130
# preserve all ops listed by all partitioners first
11031131
all_ops_no_decomp = set()
1132+
all_ops_no_decomp_needing_preservation = []
11041133
for curr_partitioner in partitioner.get(name, []):
11051134
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1106-
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1107-
curr_ops_no_decomp
1108-
)
11091135
all_ops_no_decomp |= set(curr_ops_no_decomp)
11101136

1111-
table = _default_decomposition_table()
1137+
# If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1138+
# Otherwise there will be issues
1139+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1140+
all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1141+
list(all_ops_no_decomp)
1142+
)
1143+
all_ops_no_decomp = set(all_ops_no_decomp)
11121144

1145+
# Run default decompositions, except for those in all_ops_no_decomp
1146+
table = _default_decomposition_table()
11131147
for op in all_ops_no_decomp:
1114-
table.pop(op, None)
1115-
1148+
if table.pop(op, None) is not None:
1149+
all_ops_no_decomp_needing_preservation.append(op)
11161150
program = program.run_decompositions(table)
1151+
11171152
# Among all the preserved aten ops, use the check_op_fn to do an additional
11181153
# check on which ops need to be preserved and which ops need to be decomposed
11191154
# Those which are truly preserved will be replaced with transformed ops
1120-
ops_set_to_not_decompose_by_program[name] = (
1121-
_replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
1122-
)
1123-
program = program.run_decompositions(_default_decomposition_table())
1155+
if can_skip_using_EDGE_DO_NOT_DECOMP:
1156+
ops_set_to_not_decompose_by_program[name] = (
1157+
all_ops_no_decomp_needing_preservation
1158+
)
1159+
else:
1160+
ops_set_to_not_decompose_by_program[name] = (
1161+
_replace_aten_ops_with_transformed_ops(name, program, partitioner)
1162+
or []
1163+
)
11241164

1125-
_restore_transformed_ops_to_aten_ops(program)
1165+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1166+
program = program.run_decompositions(_default_decomposition_table())
1167+
_restore_transformed_ops_to_aten_ops(program)
11261168

11271169
edge_programs[name] = program
1128-
11291170
edge_programs[name] = _generate_edge_program(
11301171
config,
11311172
program,
@@ -1169,7 +1210,7 @@ def collect_named_data_store_outputs(
11691210

11701211

11711212
@et_logger("to_edge_transform_and_lower")
1172-
def to_edge_transform_and_lower(
1213+
def to_edge_transform_and_lower( # noqa: C901
11731214
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
11741215
transform_passes: Optional[
11751216
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
@@ -1234,6 +1275,9 @@ def to_edge_transform_and_lower(
12341275
elif partitioner is None:
12351276
partitioner = {name: [] for name in aten_programs.keys()}
12361277

1278+
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1279+
partitioner, aten_programs
1280+
)
12371281
edge_manager = _gen_edge_manager_for_partitioners(
12381282
partitioner, aten_programs, config, constant_methods
12391283
)
@@ -1259,7 +1303,8 @@ def to_edge_transform_and_lower(
12591303
curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
12601304
program
12611305
)
1262-
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
1306+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1307+
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
12631308
ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
12641309
_sanity_check_graph_for_non_decomp_ops(
12651310
name,

0 commit comments

Comments
 (0)