Commit 4ec3b3d ("up")
1 parent: d4c78ab

1 file changed: 58 additions, 21 deletions

backends/apple/coreml/partition/coreml_partitioner.py (+58, -21)
@@ -20,6 +20,7 @@
     PartitionResult,
 )
 from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
+from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 from torch.fx.passes.operator_support import OperatorSupportBase
@@ -56,15 +57,67 @@ def log_once(self, msg: str) -> None:
         logger.info(msg)
         self._logged_msgs.add(msg)
 
+    def should_skip_op_for_delegation(self, node_target_name: str) -> bool:
+        skipped_ops = self.skip_ops_for_coreml_delegation or []
+
+        if node_target_name in skipped_ops:
+            return True
+
+        # For backwards compatibility
+        split_name = node_target_name.split("::")
+        if len(split_name) == 2:
+            namespace, name_without_namespace = split_name
+            if namespace == "aten" and name_without_namespace in skipped_ops:
+                return True
+
+        return False
+
+    def should_override_support(self, node) -> bool:
+        # https://github.com/apple/coremltools/issues/2573
+        if (
+            node.target
+            in [
+                torch.ops.aten.sub.Tensor,
+                exir_ops.edge.aten.sub.Tensor,
+                torch.ops.aten.add.Tensor,
+                exir_ops.edge.aten.add.Tensor,
+            ]
+            and "alpha" in node.kwargs
+        ):
+            self.log_once(
+                "torch.ops.aten.{sub, add}.Tensor with alpha is not supported by CoreML. Overriding support."
+            )
+            return True
+
+        # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
+        # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
+        # # in the placeholders due to partitioning, which CoreML does not support
+        # if not self.lower_full_graph and any(
+        #     isinstance(arg, torch.fx.Node)
+        #     and isinstance(
+        #         arg.meta.get("val", None),
+        #         (torch.SymInt, torch.SymBool, torch.SymFloat),
+        #     )
+        #     for arg in node.args
+        # ):
+        #     self.log_once(
+        #         "Skipping op for CoreML delegation because it contains symbolic args: "
+        #         + node_target_name
+        #     )
+        #     return True
+
+        return False
+
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         # get_attr node can always be supported on any backend
         if node.op == "get_attr":
             return True
         # check if the PyTorch op get called is supported in Core ML
         elif node.op == "call_function":
             # skip ops if specified by user
-            node_target_name = getattr(node.target, "__name__", "").lower()
-            if node_target_name in (self.skip_ops_for_coreml_delegation or []):
+            node_target_name = node.target.name().lower()
+
+            if self.should_skip_op_for_delegation(node_target_name):
                 self.log_once(
                     "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
                     + node_target_name
@@ -74,28 +127,13 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
                 ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
                 return False
 
-            # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
-            # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
-            # # in the placeholders due to partitioning, which CoreML does not support
-            # if not self.lower_full_graph and any(
-            #     isinstance(arg, torch.fx.Node)
-            #     and isinstance(
-            #         arg.meta.get("val", None),
-            #         (torch.SymInt, torch.SymBool, torch.SymFloat),
-            #     )
-            #     for arg in node.args
-            # ):
-            #     self.log_once(
-            #         "Skipping op for CoreML delegation because it contains symbolic args: "
-            #         + node_target_name
-            #     )
-            #     assert not self.lower_full_graph
-            #     return False
-
             # query coremltools to see if node is supported
             is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
                 node
             )
+            if self.should_override_support(node):
+                is_supported = False
+
             if not is_supported:
                 if self.lower_full_graph:
                     raise NotImplementedError(
@@ -126,7 +164,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
 
 
 class CoreMLPartitioner(Partitioner):
-
     def __init__(
         self,
         *,
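
For reference, a minimal usage sketch (not part of this commit) of the skip-list matching introduced by should_skip_op_for_delegation: an entry can match the fully qualified, lowercased name returned by node.target.name().lower(), and bare aten names are still honored for backwards compatibility. The constructor keyword below is assumed from the self.skip_ops_for_coreml_delegation attribute used in the diff (the __init__ signature is truncated here), and the import path follows the file path shown above.

# Hypothetical usage sketch; constructor keyword and exact op-name strings are assumptions.
from executorch.backends.apple.coreml.partition.coreml_partitioner import (
    CoreMLPartitioner,
)

partitioner = CoreMLPartitioner(
    skip_ops_for_coreml_delegation=[
        "aten::sub.tensor",  # fully qualified form, as produced by node.target.name().lower()
        "sub.tensor",        # bare aten name; still accepted for backwards compatibility
    ]
)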

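Likewise, a small illustration (again not part of the commit) of the pattern that should_override_support now rejects per coremltools issue 2573: aten add/sub with a non-default alpha kwarg. Such nodes are reported as unsupported, so they are not claimed for CoreML delegation (and with lower_full_graph=True they raise NotImplementedError). The export snippet below is a sketch; the exact node printout may differ by torch version.

import torch

# Illustration only: torch.sub with a non-default alpha typically exports to
# aten.sub.Tensor with "alpha" in node.kwargs, which should_override_support
# now treats as unsupported for CoreML.
class AlphaSub(torch.nn.Module):
    def forward(self, x, y):
        return torch.sub(x, y, alpha=2)

ep = torch.export.export(AlphaSub(), (torch.randn(4), torch.randn(4)))
for node in ep.graph.nodes:
    if node.op == "call_function" and "alpha" in node.kwargs:
        print(node.target, node.kwargs)  # e.g. aten.sub.Tensor {'alpha': 2}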