Skip to content

Commit 7c5bfc2

Browse files
committed
init
1 parent 9e05d89 commit 7c5bfc2

File tree

1 file changed

+48
-10
lines changed

1 file changed

+48
-10
lines changed

exir/program/_program.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,27 @@ def keep(op):
10821082
return list(filter(keep, ops_to_not_decompose))
10831083

10841084

1085+
def _can_skip_using_EDGE_DO_NOT_DECOMP(
1086+
partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram]
1087+
) -> bool:
1088+
# THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
1089+
# has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
1090+
# fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
1091+
# and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
1092+
# EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
1093+
# As a temp fix, we give a more reliable path for backends that do not specify check_op_support
1094+
can_skip_using_EDGE_DO_NOT_DECOMP = True
1095+
for name, program in aten_programs.items():
1096+
if partitioner is not None:
1097+
for curr_partitioner in partitioner.get(name, []):
1098+
curr_ops_no_decomp, check_op_support = (
1099+
curr_partitioner.ops_to_not_decompose(program)
1100+
)
1101+
if check_op_support is not None:
1102+
can_skip_using_EDGE_DO_NOT_DECOMP = False
1103+
return can_skip_using_EDGE_DO_NOT_DECOMP
1104+
1105+
10851106
def _gen_edge_manager_for_partitioners(
10861107
partitioner: Dict[str, List[Partitioner]],
10871108
aten_programs: Dict[str, ExportedProgram],
@@ -1101,37 +1122,54 @@ def _gen_edge_manager_for_partitioners(
11011122
on nodes with preserved aten targets. They are then replaces with transformed ops to
11021123
keep them through the second pass of decompositions
11031124
"""
1125+
1126+
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1127+
partitioner, aten_programs
1128+
)
1129+
11041130
ops_set_to_not_decompose_by_program = {}
11051131
edge_programs: Dict[str, ExportedProgram] = {}
11061132
for name, program in aten_programs.items():
11071133
if partitioner is not None:
11081134
# preserve all ops listed by all partitioners first
11091135
all_ops_no_decomp = set()
1136+
1137+
# This holds the subset of all_ops_no_decomp that actually need preservation, i.e.,
1138+
# the ones where the decomposition table has an entry for the op
1139+
all_ops_no_decomp_needing_preservation = []
11101140
for curr_partitioner in partitioner.get(name, []):
11111141
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1112-
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1113-
curr_ops_no_decomp
1114-
)
1142+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1143+
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1144+
curr_ops_no_decomp
1145+
)
11151146
all_ops_no_decomp |= set(curr_ops_no_decomp)
11161147

11171148
table = _default_decomposition_table()
11181149

11191150
for op in all_ops_no_decomp:
1120-
table.pop(op, None)
1151+
if table.pop(op, None) is not None:
1152+
all_ops_no_decomp_needing_preservation.append(op)
11211153

11221154
program = program.run_decompositions(table)
11231155
# Among all the preserved aten ops, use the check_op_fn to do an additional
11241156
# check on which ops need to be preserved and which ops need to be decomposed
11251157
# Those which are truly preserved will be replaced with transformed ops
1126-
ops_set_to_not_decompose_by_program[name] = (
1127-
_replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
1128-
)
1129-
program = program.run_decompositions(_default_decomposition_table())
1158+
if can_skip_using_EDGE_DO_NOT_DECOMP:
1159+
ops_set_to_not_decompose_by_program[name] = (
1160+
all_ops_no_decomp_needing_preservation
1161+
)
1162+
else:
1163+
ops_set_to_not_decompose_by_program[name] = (
1164+
_replace_aten_ops_with_transformed_ops(name, program, partitioner)
1165+
or []
1166+
)
11301167

1131-
_restore_transformed_ops_to_aten_ops(program)
1168+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1169+
program = program.run_decompositions(_default_decomposition_table())
1170+
_restore_transformed_ops_to_aten_ops(program)
11321171

11331172
edge_programs[name] = program
1134-
11351173
edge_programs[name] = _generate_edge_program(
11361174
name,
11371175
config,

0 commit comments

Comments
 (0)