 
 class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
     def __init__(
-        self, skip_ops_for_coreml_delegation: Optional[List[str]] = None
+        self,
+        skip_ops_for_coreml_delegation: Optional[List[str]] = None,
+        lower_full_graph: bool = False,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
         super().__init__()
         self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
+        self.lower_full_graph = lower_full_graph
+        if self.lower_full_graph:
+            assert (
+                len(self.skip_ops_for_coreml_delegation or []) == 0
+            ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
+        self._logged_skips = set()
 
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         # get_attr node can always be supported on any backend
@@ -44,14 +52,74 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             # skip ops if specified by user
             node_target_name = getattr(node.target, "__name__", "").lower()
             if node_target_name in (self.skip_ops_for_coreml_delegation or []):
+                skip_str = (
+                    "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
+                    + node_target_name
+                )
+                if skip_str not in self._logged_skips:
+                    logging.info(skip_str)
+                    self._logged_skips.add(skip_str)
+                assert (
+                    not self.lower_full_graph
+                ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
+                return False
+
+            # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
+            # in the placeholders due to partitioning, which CoreML does not support
+            if not self.lower_full_graph and any(
+                isinstance(arg, torch.fx.Node)
+                and isinstance(
+                    arg.meta.get("val", None),
+                    (torch.SymInt, torch.SymBool, torch.SymFloat),
+                )
+                for arg in node.args
+            ):
+                skip_str = (
+                    "Skipping op for CoreML delegation because it contains symbolic args: "
+                    + node_target_name
+                )
+                if skip_str not in self._logged_skips:
+                    logging.info(skip_str)
+                    self._logged_skips.add(skip_str)
+                assert not self.lower_full_graph
                 return False
+
             # query coremltools to see if node is supported
-            return ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+            is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
+                node
+            )
+            if not is_supported:
+                if self.lower_full_graph:
+                    raise NotImplementedError(
+                        f"""CoreML does not support the op {node_target_name}, but you have set lower_full_graph=True in the CoreMLPartitioner.
+
+Please set lower_full_graph=False in the CoreMLPartitioner to allow running unsupported ops outside of CoreML. Note that setting lower_full_graph=False may affect performance of CoreML and the available features.
+As an alternative to setting lower_full_graph=False, you can try rewriting your model to avoid using this op.
+
+Also consider filing an issue with Apple's coremltools repo to request support for the op: https://github.com/apple/coremltools/issues
+Do not file an issue with ExecuTorch for op support.
+"""
+                    )
+                skip_str = (
+                    "Skipping op for CoreML delegation because it is not supported by CoreML: "
+                    + node_target_name
+                )
+                if skip_str not in self._logged_skips:
+                    logging.info(skip_str)
+                    self._logged_skips.add(skip_str)
+            return is_supported
         # cowardly refuse to support all other types of node:
         # 1. placeholder / output nodes should not be tagged
         #    reference: https://github.com/pytorch/executorch/pull/1398
         # 2. call_module / call_method should have been replaced with call_function?
         else:
+            skip_str = (
+                "Skipping op for CoreML delegation because it is not get_attr or call_function: "
+                + node.op
+            )
+            if skip_str not in self._logged_skips:
+                logging.info(skip_str)
+                self._logged_skips.add(skip_str)
             return False
 
 
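Example (illustrative, not part of this diff): the caller-visible effect of the new lower_full_graph flag in is_node_supported. Only CoreMLPartitioner(lower_full_graph=...) and the NotImplementedError path come from this change; MyModel is a placeholder module, and the export flow assumes the standard torch.export plus to_edge_transform_and_lower pipeline with the documented CoreMLPartitioner import path.

import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower


class MyModel(torch.nn.Module):  # placeholder module for illustration
    def forward(self, x):
        return x.relu()


ep = torch.export.export(MyModel().eval(), (torch.randn(2, 3),))

# With lower_full_graph=True, any node that coremltools cannot convert makes
# partitioning raise NotImplementedError instead of quietly leaving that node
# to run outside Core ML; with the default lower_full_graph=False the node is
# simply not tagged and falls back to the portable ops.
try:
    lowered = to_edge_transform_and_lower(
        ep, partitioner=[CoreMLPartitioner(lower_full_graph=True)]
    )
except NotImplementedError as e:
    print(e)  # message suggests lower_full_graph=False or rewriting the model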
@@ -62,6 +130,8 @@ def __init__(
         skip_ops_for_coreml_delegation: Optional[List[str]] = None,
         compile_specs: Optional[List[CompileSpec]] = None,
         take_over_mutable_buffer: Optional[bool] = True,
+        lower_full_graph: bool = False,
+        tag_constant_data: bool = True,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
@@ -71,6 +141,8 @@ def __init__(
             compile_specs=compile_specs if compile_specs is not None else [],
         )
         self.take_over_mutable_buffer = take_over_mutable_buffer
+        self.lower_full_graph = lower_full_graph
+        self.tag_constant_data = tag_constant_data
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         # Run the CapabilityBasedPartitioner to return the largest possible
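Example (illustrative, not part of this diff): how the two new constructor arguments combine, based only on the checks added above. The specific op-name string is hypothetical; skip names are matched against node.target.__name__.lower() as shown in is_node_supported.

from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# Default behavior: partial delegation, and tag_constant_data(exported_program)
# is still called during partition().
partitioner = CoreMLPartitioner()

# Require the whole graph to lower to Core ML and skip tagging constant data
# into the delegated partitions.
strict_partitioner = CoreMLPartitioner(
    lower_full_graph=True,
    tag_constant_data=False,
)

# Invalid combination: a skip list contradicts lower_full_graph=True and trips
# the assertion in OperatorsSupportedForCoreMLBackend.__init__ as soon as the
# partitioner instantiates it (in partition() or ops_to_not_decompose()).
conflicting_partitioner = CoreMLPartitioner(
    skip_ops_for_coreml_delegation=["mul.tensor"],  # hypothetical op name
    lower_full_graph=True,
)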
@@ -80,7 +152,9 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
 
         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            OperatorsSupportedForCoreMLBackend(self.skip_ops_for_coreml_delegation),
+            OperatorsSupportedForCoreMLBackend(
+                self.skip_ops_for_coreml_delegation, self.lower_full_graph
+            ),
             allows_single_node_partition=True,
         )
         partition_list = capability_partitioner.propose_partitions()
@@ -90,7 +164,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 node.meta["delegation_tag"] = tag
                 partition_tags[tag] = self.delegation_spec
 
-        tag_constant_data(exported_program)
+        if self.tag_constant_data:
+            tag_constant_data(exported_program)
         if self.take_over_mutable_buffer:
             logger.info(
                 "Core ML partitioner will take over torch mutable buffer as Core ML state, "
@@ -109,7 +184,9 @@ def ops_to_not_decompose(
         self, ep: ExportedProgram
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
         do_not_decompose = []
-        op_support = OperatorsSupportedForCoreMLBackend()
+        op_support = OperatorsSupportedForCoreMLBackend(
+            self.skip_ops_for_coreml_delegation, self.lower_full_graph
+        )
         _logged_warnings = set()
 
         # CoreML prevents certain ops (like triu) from lowering to CoreML when put in the ExecuTorch op namespace
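Example (illustrative, not part of this diff): ops_to_not_decompose is normally driven by to_edge_transform_and_lower rather than called directly, but it can be inspected by hand. The exported program ep is assumed to come from a flow like the earlier sketch.

# The op-support query now uses the same skip list and lower_full_graph setting
# as partition(), so the "do not decompose" list matches what will actually be
# delegated.
partitioner = CoreMLPartitioner(lower_full_graph=True)
do_not_decompose, check_fn = partitioner.ops_to_not_decompose(ep)
print([str(op) for op in do_not_decompose])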