|
26 | 26 | from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap |
27 | 27 | from executorch.exir.error import ExportError |
28 | 28 | from executorch.exir.graph_module import get_control_flow_submodules |
| 29 | +from executorch.exir.operator.convert import _pybind_schema_to_native_schema |
29 | 30 | from executorch.exir.pass_base import PassBase |
30 | 31 | from executorch.exir.pass_manager import PassType |
31 | 32 | from executorch.exir.passes import ( |
@@ -836,6 +837,9 @@ def _replace_aten_ops_with_transformed_ops( |
836 | 837 | ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose( |
837 | 838 | program |
838 | 839 | ) |
| 840 | + ops_set_to_not_decompose = _remove_invalid_ops_for_not_decompose( |
| 841 | + ops_set_to_not_decompose |
| 842 | + ) |
839 | 843 |
|
840 | 844 | for op_aten in ops_set_to_not_decompose: |
841 | 845 | _register_no_decomp_op(op_aten) |
@@ -965,6 +969,47 @@ def _sanity_check_graph_for_non_decomp_ops( |
965 | 969 | logging.warning(warning_str) |
966 | 970 |
|
967 | 971 |
|
| 972 | +def _remove_invalid_ops_for_not_decompose( |
| 973 | + ops_to_not_decompose: List[torch._ops.OpOverload], |
| 974 | +) -> List[torch._ops.OpOverload]: |
| 975 | + # To address https://github.com/pytorch/executorch/issues/8781 |
| 976 | + def keep(op): |
| 977 | + schema = op._schema |
| 978 | + native_schema = _pybind_schema_to_native_schema(schema) |
| 979 | + if native_schema.is_mutable: |
| 980 | + logging.warn( |
| 981 | + f"Op {op} was requested for preservation by partitioner. This request is ignored because it is mutable." |
| 982 | + ) |
| 983 | + return False |
| 984 | + |
| 985 | + if native_schema.aliased_return_names() != [None]: |
| 986 | + logging.warn( |
| 987 | + f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output." |
| 988 | + ) |
| 989 | + return False |
| 990 | + |
| 991 | + # Explicit block list of ops that don't work if asked for |
| 992 | + # preservation |
| 993 | + if op in [ |
| 994 | + # Hits infinte recursion error when op is in |
| 995 | + # EDGE_DO_NOT_DECOMP namespace |
| 996 | + torch.ops.aten._to_copy.default, |
| 997 | + # scalar to tensor type promotion does not work on ops |
| 998 | + # in EDGE_DO_NOT_DECOMP namespace |
| 999 | + torch.ops.aten.mul.Tensor, |
| 1000 | + torch.ops.aten.add.Tensor, |
| 1001 | + torch.ops.aten.sub.Tensor, |
| 1002 | + torch.ops.aten.div.Tensor, |
| 1003 | + ]: |
| 1004 | + logging.warn( |
| 1005 | + f"Op {op} was requested for preservation by partitioner. This request is ignored because it is in a blocklist." |
| 1006 | + ) |
| 1007 | + return False |
| 1008 | + return True |
| 1009 | + |
| 1010 | + return list(filter(keep, ops_to_not_decompose)) |
| 1011 | + |
| 1012 | + |
968 | 1013 | def _gen_edge_manager_for_partitioners( |
969 | 1014 | partitioner: Dict[str, List[Partitioner]], |
970 | 1015 | aten_programs: Dict[str, ExportedProgram], |
@@ -992,6 +1037,9 @@ def _gen_edge_manager_for_partitioners( |
992 | 1037 | all_ops_no_decomp = set() |
993 | 1038 | for curr_partitioner in partitioner.get(name, []): |
994 | 1039 | curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program) |
| 1040 | + curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose( |
| 1041 | + curr_ops_no_decomp |
| 1042 | + ) |
995 | 1043 | all_ops_no_decomp |= set(curr_ops_no_decomp) |
996 | 1044 |
|
997 | 1045 | table = _default_decomposition_table() |
@@ -1113,6 +1161,7 @@ def to_edge_transform_and_lower( |
1113 | 1161 | curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose( |
1114 | 1162 | program |
1115 | 1163 | ) |
| 1164 | + curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set) |
1116 | 1165 | ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set) |
1117 | 1166 | _sanity_check_graph_for_non_decomp_ops( |
1118 | 1167 | name, |
|
0 commit comments