Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,27 @@ def keep(op):
return list(filter(keep, ops_to_not_decompose))


def _can_skip_using_EDGE_DO_NOT_DECOMP(
partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram]
) -> bool:
# THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
# has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
# fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
# and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
# EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
# As a temp fix, we give a more reliable path for backends that do not specify check_op_support
can_skip_using_EDGE_DO_NOT_DECOMP = True
for name, program in aten_programs.items():
if partitioner is not None:
for curr_partitioner in partitioner.get(name, []):
curr_ops_no_decomp, check_op_support = (
curr_partitioner.ops_to_not_decompose(program)
)
if check_op_support is not None:
can_skip_using_EDGE_DO_NOT_DECOMP = False
return can_skip_using_EDGE_DO_NOT_DECOMP


def _gen_edge_manager_for_partitioners(
partitioner: Dict[str, List[Partitioner]],
aten_programs: Dict[str, ExportedProgram],
Expand All @@ -1101,37 +1122,54 @@ def _gen_edge_manager_for_partitioners(
on nodes with preserved aten targets. They are then replaces with transformed ops to
keep them through the second pass of decompositions
"""

can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
partitioner, aten_programs
)

ops_set_to_not_decompose_by_program = {}
edge_programs: Dict[str, ExportedProgram] = {}
for name, program in aten_programs.items():
if partitioner is not None:
# preserve all ops listed by all partitioners first
all_ops_no_decomp = set()

# This holds the subset of all_ops_no_decomp that actually need preservation, i.e.,
# the ones where the decomposition table has an entry for the op
all_ops_no_decomp_needing_preservation = []
for curr_partitioner in partitioner.get(name, []):
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
curr_ops_no_decomp
)
if not can_skip_using_EDGE_DO_NOT_DECOMP:
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
curr_ops_no_decomp
)
all_ops_no_decomp |= set(curr_ops_no_decomp)

table = _default_decomposition_table()

for op in all_ops_no_decomp:
table.pop(op, None)
if table.pop(op, None) is not None:
all_ops_no_decomp_needing_preservation.append(op)

program = program.run_decompositions(table)
# Among all the preserved aten ops, use the check_op_fn to do an additional
# check on which ops need to be preserved and which ops need to be decomposed
# Those which are truly preserved will be replaced with transformed ops
ops_set_to_not_decompose_by_program[name] = (
_replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
)
program = program.run_decompositions(_default_decomposition_table())
if can_skip_using_EDGE_DO_NOT_DECOMP:
ops_set_to_not_decompose_by_program[name] = (
all_ops_no_decomp_needing_preservation
)
else:
ops_set_to_not_decompose_by_program[name] = (
_replace_aten_ops_with_transformed_ops(name, program, partitioner)
or []
)

_restore_transformed_ops_to_aten_ops(program)
if not can_skip_using_EDGE_DO_NOT_DECOMP:
program = program.run_decompositions(_default_decomposition_table())
_restore_transformed_ops_to_aten_ops(program)

edge_programs[name] = program

edge_programs[name] = _generate_edge_program(
name,
config,
Expand Down
Loading