20
20
PartitionResult ,
21
21
)
22
22
from executorch .exir .backend .utils import tag_constant_data , tag_mutated_buffer
23
+ from executorch .exir .dialects ._ops import ops as exir_ops
23
24
from torch .export .exported_program import ExportedProgram
24
25
from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
25
26
from torch .fx .passes .operator_support import OperatorSupportBase
@@ -56,6 +57,80 @@ def log_once(self, msg: str) -> None:
56
57
logger .info (msg )
57
58
self ._logged_msgs .add (msg )
58
59
60
+ def should_skip_op_for_delegation (self , node_target_name : str ) -> bool :
61
+ skipped_ops = self .skip_ops_for_coreml_delegation or []
62
+ if node_target_name in skipped_ops :
63
+ assert (
64
+ not self .lower_full_graph
65
+ ), f"Cannot skip { node_target_name } because lower_full_graph is True. Please set skip_ops_for_coreml_delegation=None or lower_full_graph=False in the CoreMLPartitioner"
66
+ self .log_once (
67
+ "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
68
+ + node_target_name
69
+ )
70
+ return True
71
+ return False
72
+
73
+ def should_override_support (self , node ) -> bool :
74
+ # https://github.com/apple/coremltools/issues/2573
75
+ if (
76
+ node .target
77
+ in [
78
+ torch .ops .aten .sub .Tensor ,
79
+ exir_ops .edge .aten .sub .Tensor ,
80
+ torch .ops .aten .add .Tensor ,
81
+ exir_ops .edge .aten .add .Tensor ,
82
+ ]
83
+ and "alpha" in node .kwargs
84
+ and node .kwargs ["alpha" ] != 1
85
+ ):
86
+ self .log_once (
87
+ "torch.ops.aten.{sub, add}.Tensor with alpha != 1 is not supported by CoreML. Overriding support."
88
+ )
89
+ return True
90
+
91
+ # https://github.com/apple/coremltools/issues/2565
92
+ if node .target in [
93
+ torch .ops .aten .diagonal .default ,
94
+ torch .ops .aten .diagonal_copy .default ,
95
+ exir_ops .edge .aten .diagonal .default ,
96
+ exir_ops .edge .aten .diagonal_copy .default ,
97
+ ]:
98
+ self .log_once (
99
+ "torch.ops.aten.diagonal.default has a bug in CoreML. Overriding op support."
100
+ )
101
+ return True
102
+
103
+ # https://github.com/apple/coremltools/issues/2569
104
+ if node .target in [
105
+ torch .ops .aten .acosh .default ,
106
+ exir_ops .edge .aten .acosh .default ,
107
+ torch .ops .aten .asinh .default ,
108
+ exir_ops .edge .aten .asinh .default ,
109
+ ]:
110
+ self .log_once (
111
+ "torch.ops.aten.{acosh, asinh}.default is not supported by CoreML. Overriding op support."
112
+ )
113
+ return True
114
+
115
+ # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
116
+ # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
117
+ # # in the placeholders due to partitioning, which CoreML does not support
118
+ # if not self.lower_full_graph and any(
119
+ # isinstance(arg, torch.fx.Node)
120
+ # and isinstance(
121
+ # arg.meta.get("val", None),
122
+ # (torch.SymInt, torch.SymBool, torch.SymFloat),
123
+ # )
124
+ # for arg in node.args
125
+ # ):
126
+ # self.log_once(
127
+ # "Skipping op for CoreML delegation because it contains symbolic args: "
128
+ # + node_target_name
129
+ # )
130
+ # return True
131
+
132
+ return False
133
+
59
134
def is_node_supported (self , submodules , node : torch .fx .Node ) -> bool :
60
135
# get_attr node can always be supported on any backend
61
136
if node .op == "get_attr" :
@@ -64,38 +139,17 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
64
139
elif node .op == "call_function" :
65
140
# skip ops if specified by user
66
141
node_target_name = getattr (node .target , "__name__" , "" ).lower ()
67
- if node_target_name in (self .skip_ops_for_coreml_delegation or []):
68
- self .log_once (
69
- "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
70
- + node_target_name
71
- )
72
- assert (
73
- not self .lower_full_graph
74
- ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
75
- return False
76
142
77
- # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
78
- # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
79
- # # in the placeholders due to partitioning, which CoreML does not support
80
- # if not self.lower_full_graph and any(
81
- # isinstance(arg, torch.fx.Node)
82
- # and isinstance(
83
- # arg.meta.get("val", None),
84
- # (torch.SymInt, torch.SymBool, torch.SymFloat),
85
- # )
86
- # for arg in node.args
87
- # ):
88
- # self.log_once(
89
- # "Skipping op for CoreML delegation because it contains symbolic args: "
90
- # + node_target_name
91
- # )
92
- # assert not self.lower_full_graph
93
- # return False
143
+ if self .should_skip_op_for_delegation (node_target_name ):
144
+ return False
94
145
95
146
# query coremltools to see if node is supported
96
147
is_supported = ct .converters .mil .frontend .torch .is_torch_fx_node_supported (
97
148
node
98
149
)
150
+ if self .should_override_support (node ):
151
+ is_supported = False
152
+
99
153
if not is_supported :
100
154
if self .lower_full_graph :
101
155
raise NotImplementedError (
@@ -126,7 +180,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
126
180
127
181
128
182
class CoreMLPartitioner (Partitioner ):
129
-
130
183
def __init__ (
131
184
self ,
132
185
* ,
0 commit comments