diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py index 99aa2a0a60e..210ef307477 100644 --- a/backends/apple/coreml/partition/coreml_partitioner.py +++ b/backends/apple/coreml/partition/coreml_partitioner.py @@ -111,10 +111,16 @@ def ops_to_not_decompose( do_not_decompose = [] op_support = OperatorsSupportedForCoreMLBackend() for node in ep.graph.nodes: - if ( - node.op == "call_function" - and isinstance(node.target, torch._ops.OpOverload) - and op_support.is_node_supported(None, node) + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload ): - do_not_decompose.append(node.target) + try: + if op_support.is_node_supported(None, node): + do_not_decompose.append(node.target) + except Exception as e: + # CoreML's op_support.is_node_supported will sometimes throw + # for unsupported ops, rather than returning False + logger.warning( + f"Encountered exception when checking node support: {e}" + ) return do_not_decompose, None diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py index 03aac6a8611..7683d9c44d1 100644 --- a/backends/apple/coreml/test/test_coreml_partitioner.py +++ b/backends/apple/coreml/test/test_coreml_partitioner.py @@ -82,11 +82,28 @@ def test_vit_skip_conv(self): def test_ops_to_not_decompose(self): class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + def forward(self, q, k, v, mask): - return torch.ops.aten.scaled_dot_product_attention.default( + out = torch.ops.aten.scaled_dot_product_attention.default( q, k, v, attn_mask=mask ) + # Add non-functional and alias ops + # These will be removed by ExecuTorch in non-decomposition + # table because they cannot be functionalized + out = out.transpose(1, 2) + out = out.view(1, -1) + out = out.permute(0, 1) + out = out.add_(1.0) + out = out.mul_(2.0) + out = out.div_(3.0) + out = out.sub_(4.0) + out = 
torch.ops.aten.view_copy.default(out, (-1,)) + out = out.select(0, 0) + return out + model = Model() model.eval() diff --git a/exir/program/_program.py b/exir/program/_program.py index fdf4b93e19c..739765be0d5 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -26,6 +26,7 @@ from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap from executorch.exir.error import ExportError from executorch.exir.graph_module import get_control_flow_submodules +from executorch.exir.operator.convert import _pybind_schema_to_native_schema from executorch.exir.pass_base import PassBase from executorch.exir.pass_manager import PassType from executorch.exir.passes import ( @@ -836,6 +837,9 @@ def _replace_aten_ops_with_transformed_ops( ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose( program ) + ops_set_to_not_decompose = _remove_invalid_ops_for_not_decompose( + ops_set_to_not_decompose + ) for op_aten in ops_set_to_not_decompose: _register_no_decomp_op(op_aten) @@ -965,6 +969,47 @@ def _sanity_check_graph_for_non_decomp_ops( logging.warning(warning_str) +def _remove_invalid_ops_for_not_decompose( + ops_to_not_decompose: List[torch._ops.OpOverload], +) -> List[torch._ops.OpOverload]: + # To address https://github.com/pytorch/executorch/issues/8781 + def keep(op): + schema = op._schema + native_schema = _pybind_schema_to_native_schema(schema) + if native_schema.is_mutable: + logging.warn( + f"Op {op} was requested for preservation by partitioner. This request is ignored because it is mutable." + ) + return False + + if native_schema.aliased_return_names() != [None]: + logging.warn( + f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output." 
+ ) + return False + + # Explicit block list of ops that don't work if asked for + # preservation + if op in [ + # Hits infinite recursion error when op is in + # EDGE_DO_NOT_DECOMP namespace + torch.ops.aten._to_copy.default, + # scalar to tensor type promotion does not work on ops + # in EDGE_DO_NOT_DECOMP namespace + torch.ops.aten.mul.Tensor, + torch.ops.aten.add.Tensor, + torch.ops.aten.sub.Tensor, + torch.ops.aten.div.Tensor, + ]: + logging.warn( + f"Op {op} was requested for preservation by partitioner. This request is ignored because it is in a blocklist." + ) + return False + return True + + return list(filter(keep, ops_to_not_decompose)) + + def _gen_edge_manager_for_partitioners( partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram], @@ -992,6 +1037,9 @@ def _gen_edge_manager_for_partitioners( all_ops_no_decomp = set() for curr_partitioner in partitioner.get(name, []): curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program) + curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose( + curr_ops_no_decomp + ) all_ops_no_decomp |= set(curr_ops_no_decomp) table = _default_decomposition_table() @@ -1113,6 +1161,7 @@ def to_edge_transform_and_lower( curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose( program ) + curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set) ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set) _sanity_check_graph_for_non_decomp_ops( name,