diff --git a/exir/program/_program.py b/exir/program/_program.py index 8ef02f233ac..65e1155b0fc 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1082,6 +1082,27 @@ def keep(op): return list(filter(keep, ops_to_not_decompose)) +def _can_skip_using_EDGE_DO_NOT_DECOMP( + partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram] +) -> bool: + # THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition + # has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to + # fix some of the issues, but more issues are coming up over time, including a new issue with SDPA + # and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/ + # EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support + # As a temp fix, we give a more reliable path for backends that do not specify check_op_support + can_skip_using_EDGE_DO_NOT_DECOMP = True + for name, program in aten_programs.items(): + if partitioner is not None: + for curr_partitioner in partitioner.get(name, []): + curr_ops_no_decomp, check_op_support = ( + curr_partitioner.ops_to_not_decompose(program) + ) + if check_op_support is not None: + can_skip_using_EDGE_DO_NOT_DECOMP = False + return can_skip_using_EDGE_DO_NOT_DECOMP + + def _gen_edge_manager_for_partitioners( partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram], @@ -1101,37 +1122,54 @@ def _gen_edge_manager_for_partitioners( on nodes with preserved aten targets. They are then replaces with transformed ops to keep them through the second pass of decompositions """ + + can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP( + partitioner, aten_programs + ) + ops_set_to_not_decompose_by_program = {} edge_programs: Dict[str, ExportedProgram] = {} for name, program in aten_programs.items(): if partitioner is not None: # preserve all ops listed by all partitioners first all_ops_no_decomp = set() + + # This holds the subset of all_ops_no_decomp that actually need preservation, i.e., + # the ones where the decomposition table has an entry for the op + all_ops_no_decomp_needing_preservation = [] for curr_partitioner in partitioner.get(name, []): curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program) - curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose( - curr_ops_no_decomp - ) + if not can_skip_using_EDGE_DO_NOT_DECOMP: + curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose( + curr_ops_no_decomp + ) all_ops_no_decomp |= set(curr_ops_no_decomp) table = _default_decomposition_table() for op in all_ops_no_decomp: - table.pop(op, None) + if table.pop(op, None) is not None: + all_ops_no_decomp_needing_preservation.append(op) program = program.run_decompositions(table) # Among all the preserved aten ops, use the check_op_fn to do an additional # check on which ops need to be preserved and which ops need to be decomposed # Those which are truly preserved will be replaced with transformed ops - ops_set_to_not_decompose_by_program[name] = ( - _replace_aten_ops_with_transformed_ops(name, program, partitioner) or [] - ) - program = program.run_decompositions(_default_decomposition_table()) + if can_skip_using_EDGE_DO_NOT_DECOMP: + ops_set_to_not_decompose_by_program[name] = ( + all_ops_no_decomp_needing_preservation + ) + else: + ops_set_to_not_decompose_by_program[name] = ( + _replace_aten_ops_with_transformed_ops(name, program, partitioner) + or [] + ) - _restore_transformed_ops_to_aten_ops(program) + if not can_skip_using_EDGE_DO_NOT_DECOMP: + program = program.run_decompositions(_default_decomposition_table()) + _restore_transformed_ops_to_aten_ops(program) edge_programs[name] = program - edge_programs[name] = _generate_edge_program( name, config,