Commit 5d3550f

Allow overriding CoreML op support in partitioner to ignore ops where CoreML has bugs (pytorch#13023)
1 parent 0d9ce1c commit 5d3550f

File tree

1 file changed: +80 -27 lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 80 additions & 27 deletions
@@ -20,6 +20,7 @@
     PartitionResult,
 )
 from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
+from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 from torch.fx.passes.operator_support import OperatorSupportBase
@@ -56,6 +57,80 @@ def log_once(self, msg: str) -> None:
         logger.info(msg)
         self._logged_msgs.add(msg)
 
+    def should_skip_op_for_delegation(self, node_target_name: str) -> bool:
+        skipped_ops = self.skip_ops_for_coreml_delegation or []
+        if node_target_name in skipped_ops:
+            assert (
+                not self.lower_full_graph
+            ), f"Cannot skip {node_target_name} because lower_full_graph is True. Please set skip_ops_for_coreml_delegation=None or lower_full_graph=False in the CoreMLPartitioner"
+            self.log_once(
+                "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
+                + node_target_name
+            )
+            return True
+        return False
+
+    def should_override_support(self, node) -> bool:
+        # https://github.com/apple/coremltools/issues/2573
+        if (
+            node.target
+            in [
+                torch.ops.aten.sub.Tensor,
+                exir_ops.edge.aten.sub.Tensor,
+                torch.ops.aten.add.Tensor,
+                exir_ops.edge.aten.add.Tensor,
+            ]
+            and "alpha" in node.kwargs
+            and node.kwargs["alpha"] != 1
+        ):
+            self.log_once(
+                "torch.ops.aten.{sub, add}.Tensor with alpha != 1 is not supported by CoreML. Overriding support."
+            )
+            return True
+
+        # https://github.com/apple/coremltools/issues/2565
+        if node.target in [
+            torch.ops.aten.diagonal.default,
+            torch.ops.aten.diagonal_copy.default,
+            exir_ops.edge.aten.diagonal.default,
+            exir_ops.edge.aten.diagonal_copy.default,
+        ]:
+            self.log_once(
+                "torch.ops.aten.diagonal.default has a bug in CoreML. Overriding op support."
+            )
+            return True
+
+        # https://github.com/apple/coremltools/issues/2569
+        if node.target in [
+            torch.ops.aten.acosh.default,
+            exir_ops.edge.aten.acosh.default,
+            torch.ops.aten.asinh.default,
+            exir_ops.edge.aten.asinh.default,
+        ]:
+            self.log_once(
+                "torch.ops.aten.{acosh, asinh}.default is not supported by CoreML. Overriding op support."
+            )
+            return True
+
+        # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
+        # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
+        # # in the placeholders due to partitioning, which CoreML does not support
+        # if not self.lower_full_graph and any(
+        #     isinstance(arg, torch.fx.Node)
+        #     and isinstance(
+        #         arg.meta.get("val", None),
+        #         (torch.SymInt, torch.SymBool, torch.SymFloat),
+        #     )
+        #     for arg in node.args
+        # ):
+        #     self.log_once(
+        #         "Skipping op for CoreML delegation because it contains symbolic args: "
+        #         + node_target_name
+        #     )
+        #     return True
+
+        return False
+
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         # get_attr node can always be supported on any backend
         if node.op == "get_attr":
@@ -64,38 +139,17 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         elif node.op == "call_function":
            # skip ops if specified by user
             node_target_name = getattr(node.target, "__name__", "").lower()
-            if node_target_name in (self.skip_ops_for_coreml_delegation or []):
-                self.log_once(
-                    "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
-                    + node_target_name
-                )
-                assert (
-                    not self.lower_full_graph
-                ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
-                return False
 
-            # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
-            # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
-            # # in the placeholders due to partitioning, which CoreML does not support
-            # if not self.lower_full_graph and any(
-            #     isinstance(arg, torch.fx.Node)
-            #     and isinstance(
-            #         arg.meta.get("val", None),
-            #         (torch.SymInt, torch.SymBool, torch.SymFloat),
-            #     )
-            #     for arg in node.args
-            # ):
-            #     self.log_once(
-            #         "Skipping op for CoreML delegation because it contains symbolic args: "
-            #         + node_target_name
-            #     )
-            #     assert not self.lower_full_graph
-            #     return False
+            if self.should_skip_op_for_delegation(node_target_name):
+                return False
 
             # query coremltools to see if node is supported
             is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
                 node
             )
+            if self.should_override_support(node):
+                is_supported = False
+
             if not is_supported:
                 if self.lower_full_graph:
                     raise NotImplementedError(
@@ -126,7 +180,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
 
 
 class CoreMLPartitioner(Partitioner):
-
     def __init__(
         self,
         *,
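
For reference, a minimal usage sketch (not part of this commit) of how the skip list interacts with the new checks, assuming the usual ExecuTorch export-and-lower flow. The import paths follow the file changed above; the toy model, the op-name string "mul.tensor", and the lowering calls are illustrative and should be verified against your ExecuTorch version.

import torch

from executorch.backends.apple.coreml.partition.coreml_partitioner import (
    CoreMLPartitioner,
)
from executorch.exir import to_edge_transform_and_lower


class ToyModel(torch.nn.Module):
    def forward(self, x, y):
        return x + y, x * y


ep = torch.export.export(ToyModel(), (torch.randn(4), torch.randn(4)))

# Keep mul out of the CoreML delegate; should_skip_op_for_delegation matches the
# lowercased node.target.__name__ and asserts if lower_full_graph=True is also set.
partitioner = CoreMLPartitioner(skip_ops_for_coreml_delegation=["mul.tensor"])

# Ops that hit known coremltools bugs (add/sub with alpha != 1, diagonal,
# acosh/asinh) are now rejected automatically by should_override_support.
et_program = to_edge_transform_and_lower(ep, partitioner=[partitioner]).to_executorch()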
