2020 PartitionResult ,
2121)
2222from executorch .exir .backend .utils import tag_constant_data , tag_mutated_buffer
23+ from executorch .exir .dialects ._ops import ops as exir_ops
2324from torch .export .exported_program import ExportedProgram
2425from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
2526from torch .fx .passes .operator_support import OperatorSupportBase
@@ -56,6 +57,80 @@ def log_once(self, msg: str) -> None:
5657 logger .info (msg )
5758 self ._logged_msgs .add (msg )
5859
60+ def should_skip_op_for_delegation (self , node_target_name : str ) -> bool :
61+ skipped_ops = self .skip_ops_for_coreml_delegation or []
62+ if node_target_name in skipped_ops :
63+ assert (
64+ not self .lower_full_graph
65+ ), f"Cannot skip { node_target_name } because lower_full_graph is True. Please set skip_ops_for_coreml_delegation=None or lower_full_graph=False in the CoreMLPartitioner"
66+ self .log_once (
67+ "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
68+ + node_target_name
69+ )
70+ return True
71+ return False
72+
73+ def should_override_support (self , node ) -> bool :
74+ # https://github.com/apple/coremltools/issues/2573
75+ if (
76+ node .target
77+ in [
78+ torch .ops .aten .sub .Tensor ,
79+ exir_ops .edge .aten .sub .Tensor ,
80+ torch .ops .aten .add .Tensor ,
81+ exir_ops .edge .aten .add .Tensor ,
82+ ]
83+ and "alpha" in node .kwargs
84+ and node .kwargs ["alpha" ] != 1
85+ ):
86+ self .log_once (
87+ "torch.ops.aten.{sub, add}.Tensor with alpha != 1 is not supported by CoreML. Overriding support."
88+ )
89+ return True
90+
91+ # https://github.com/apple/coremltools/issues/2565
92+ if node .target in [
93+ torch .ops .aten .diagonal .default ,
94+ torch .ops .aten .diagonal_copy .default ,
95+ exir_ops .edge .aten .diagonal .default ,
96+ exir_ops .edge .aten .diagonal_copy .default ,
97+ ]:
98+ self .log_once (
99+ "torch.ops.aten.diagonal.default has a bug in CoreML. Overriding op support."
100+ )
101+ return True
102+
103+ # https://github.com/apple/coremltools/issues/2569
104+ if node .target in [
105+ torch .ops .aten .acosh .default ,
106+ exir_ops .edge .aten .acosh .default ,
107+ torch .ops .aten .asinh .default ,
108+ exir_ops .edge .aten .asinh .default ,
109+ ]:
110+ self .log_once (
111+ "torch.ops.aten.{acosh, asinh}.default is not supported by CoreML. Overriding op support."
112+ )
113+ return True
114+
115+ # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
116+ # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
117+ # # in the placeholders due to partitioning, which CoreML does not support
118+ # if not self.lower_full_graph and any(
119+ # isinstance(arg, torch.fx.Node)
120+ # and isinstance(
121+ # arg.meta.get("val", None),
122+ # (torch.SymInt, torch.SymBool, torch.SymFloat),
123+ # )
124+ # for arg in node.args
125+ # ):
126+ # self.log_once(
127+ # "Skipping op for CoreML delegation because it contains symbolic args: "
128+ # + node_target_name
129+ # )
130+ # return True
131+
132+ return False
133+
59134 def is_node_supported (self , submodules , node : torch .fx .Node ) -> bool :
60135 # get_attr node can always be supported on any backend
61136 if node .op == "get_attr" :
@@ -64,38 +139,17 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
64139 elif node .op == "call_function" :
65140 # skip ops if specified by user
66141 node_target_name = getattr (node .target , "__name__" , "" ).lower ()
67- if node_target_name in (self .skip_ops_for_coreml_delegation or []):
68- self .log_once (
69- "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
70- + node_target_name
71- )
72- assert (
73- not self .lower_full_graph
74- ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
75- return False
76142
77- # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
78- # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
79- # # in the placeholders due to partitioning, which CoreML does not support
80- # if not self.lower_full_graph and any(
81- # isinstance(arg, torch.fx.Node)
82- # and isinstance(
83- # arg.meta.get("val", None),
84- # (torch.SymInt, torch.SymBool, torch.SymFloat),
85- # )
86- # for arg in node.args
87- # ):
88- # self.log_once(
89- # "Skipping op for CoreML delegation because it contains symbolic args: "
90- # + node_target_name
91- # )
92- # assert not self.lower_full_graph
93- # return False
143+ if self .should_skip_op_for_delegation (node_target_name ):
144+ return False
94145
95146 # query coremltools to see if node is supported
96147 is_supported = ct .converters .mil .frontend .torch .is_torch_fx_node_supported (
97148 node
98149 )
150+ if self .should_override_support (node ):
151+ is_supported = False
152+
99153 if not is_supported :
100154 if self .lower_full_graph :
101155 raise NotImplementedError (
@@ -126,7 +180,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
126180
127181
128182class CoreMLPartitioner (Partitioner ):
129-
130183 def __init__ (
131184 self ,
132185 * ,
0 commit comments