2020 PartitionResult ,
2121)
2222from executorch .exir .backend .utils import tag_constant_data , tag_mutated_buffer
23+ from executorch .exir .dialects ._ops import ops as exir_ops
2324from torch .export .exported_program import ExportedProgram
2425from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
2526from torch .fx .passes .operator_support import OperatorSupportBase
@@ -56,15 +57,67 @@ def log_once(self, msg: str) -> None:
5657 logger .info (msg )
5758 self ._logged_msgs .add (msg )
5859
60+ def should_skip_op_for_delegation (self , node_target_name : str ) -> bool :
61+ skipped_ops = self .skip_ops_for_coreml_delegation or []
62+
63+ if node_target_name in skipped_ops :
64+ return True
65+
66+ # For backwards compatibility
67+ split_name = node_target_name .split ("::" )
68+ if len (split_name ) == 2 :
69+ namespace , name_without_namespace = split_name
70+ if namespace == "aten" and name_without_namespace in skipped_ops :
71+ return True
72+
73+ return False
74+
75+ def should_override_support (self , node ) -> bool :
76+ # https://github.com/apple/coremltools/issues/2573
77+ if (
78+ node .target
79+ in [
80+ torch .ops .aten .sub .Tensor ,
81+ exir_ops .edge .aten .sub .Tensor ,
82+ torch .ops .aten .add .Tensor ,
83+ exir_ops .edge .aten .add .Tensor ,
84+ ]
85+ and "alpha" in node .kwargs
86+ ):
87+ self .log_once (
88+ "torch.ops.aten.{sub, add}.Tensor with alpha is not supported by CoreML. Overriding support."
89+ )
90+ return True
91+
92+ # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
93+ # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
94+ # # in the placeholders due to partitioning, which CoreML does not support
95+ # if not self.lower_full_graph and any(
96+ # isinstance(arg, torch.fx.Node)
97+ # and isinstance(
98+ # arg.meta.get("val", None),
99+ # (torch.SymInt, torch.SymBool, torch.SymFloat),
100+ # )
101+ # for arg in node.args
102+ # ):
103+ # self.log_once(
104+ # "Skipping op for CoreML delegation because it contains symbolic args: "
105+ # + node_target_name
106+ # )
107+ # return True
108+
109+ return False
110+
59111 def is_node_supported (self , submodules , node : torch .fx .Node ) -> bool :
60112 # get_attr node can always be supported on any backend
61113 if node .op == "get_attr" :
62114 return True
63115 # check if the PyTorch op get called is supported in Core ML
64116 elif node .op == "call_function" :
65117 # skip ops if specified by user
66- node_target_name = getattr (node .target , "__name__" , "" ).lower ()
67- if node_target_name in (self .skip_ops_for_coreml_delegation or []):
118+ node_target_name = node .target .name ().lower ()
119+
120+ if self .should_skip_op_for_delegation (node_target_name ):
68121 self .log_once (
69122 "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
70123 + node_target_name
@@ -74,28 +127,13 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
74127 ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
75128 return False
76129
77- # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
78- # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
79- # # in the placeholders due to partitioning, which CoreML does not support
80- # if not self.lower_full_graph and any(
81- # isinstance(arg, torch.fx.Node)
82- # and isinstance(
83- # arg.meta.get("val", None),
84- # (torch.SymInt, torch.SymBool, torch.SymFloat),
85- # )
86- # for arg in node.args
87- # ):
88- # self.log_once(
89- # "Skipping op for CoreML delegation because it contains symbolic args: "
90- # + node_target_name
91- # )
92- # assert not self.lower_full_graph
93- # return False
94-
95130 # query coremltools to see if node is supported
96131 is_supported = ct .converters .mil .frontend .torch .is_torch_fx_node_supported (
97132 node
98133 )
134+ if self .should_override_support (node ):
135+ is_supported = False
136+
99137 if not is_supported :
100138 if self .lower_full_graph :
101139 raise NotImplementedError (
@@ -126,7 +164,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
126164
127165
128166class CoreMLPartitioner (Partitioner ):
129-
130167 def __init__ (
131168 self ,
132169 * ,
0 commit comments