Commit 7248122

init
1 parent 6c51181 commit 7248122

2 files changed: +190 −5 lines changed


backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 82 additions & 5 deletions
@@ -28,12 +28,20 @@
 
 class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
     def __init__(
-        self, skip_ops_for_coreml_delegation: Optional[List[str]] = None
+        self,
+        skip_ops_for_coreml_delegation: Optional[List[str]] = None,
+        lower_full_graph: bool = False,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
         super().__init__()
         self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
+        self.lower_full_graph = lower_full_graph
+        if self.lower_full_graph:
+            assert (
+                len(self.skip_ops_for_coreml_delegation or []) == 0
+            ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
+        self._logged_skips = set()
 
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         # get_attr node can always be supported on any backend
@@ -44,14 +52,74 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             # skip ops if specified by user
             node_target_name = getattr(node.target, "__name__", "").lower()
             if node_target_name in (self.skip_ops_for_coreml_delegation or []):
+                skip_str = (
+                    "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
+                    + node_target_name
+                )
+                if skip_str not in self._logged_skips:
+                    logging.info(skip_str)
+                    self._logged_skips.add(skip_str)
+                assert (
+                    not self.lower_full_graph
+                ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
+                return False
+
+            # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
+            # in the placeholders due to partitioning, which CoreML does not support
+            if not self.lower_full_graph and any(
+                isinstance(arg, torch.fx.Node)
+                and isinstance(
+                    arg.meta.get("val", None),
+                    (torch.SymInt, torch.SymBool, torch.SymFloat),
+                )
+                for arg in node.args
+            ):
+                skip_str = (
+                    "Skipping op for CoreML delegation because it contains symbolic args: "
+                    + node_target_name
+                )
+                if skip_str not in self._logged_skips:
+                    logging.info(skip_str)
+                    self._logged_skips.add(skip_str)
+                assert not self.lower_full_graph
                 return False
+
             # query coremltools to see if node is supported
-            return ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+            is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
+                node
+            )
+            if not is_supported:
+                if self.lower_full_graph:
+                    raise NotImplementedError(
+                        f"""CoreML does not support the op {node_target_name}, but you have set lower_full_graph=True in the CoreMLPartitioner.
+
+Please set lower_full_graph=False in the CoreMLPartitioner to allow running unsupported ops outside of CoreML. Note that setting lower_full_graph=False may affect performance of CoreML and the available features.
+As an alternative to setting lower_full_graph=False, you can try rewriting your model to avoid using this op.
+
+Also consider filing an issue with Apple's coremltools repo to request support for the op: https://github.com/apple/coremltools/issues
+Do not file an issue with ExecuTorch for op support.
+"""
+                    )
+                skip_str = (
+                    "Skipping op for CoreML delegation because it is not supported by CoreML: "
+                    + node_target_name
+                )
+                if skip_str not in self._logged_skips:
+                    logging.info(skip_str)
+                    self._logged_skips.add(skip_str)
+            return is_supported
         # cowardly refuse to support all other types of node:
         # 1. placeholder / output nodes should not be tagged
         #    reference: https://github.com/pytorch/executorch/pull/1398
         # 2. call_module / call_method should have been replaced with call_function?
         else:
+            skip_str = (
+                "Skipping op for CoreML delegation because it is not get_attr or call_function: "
+                + node.op
+            )
+            if skip_str not in self._logged_skips:
+                logging.info(skip_str)
+                self._logged_skips.add(skip_str)
             return False
 
 
@@ -62,6 +130,8 @@ def __init__(
         skip_ops_for_coreml_delegation: Optional[List[str]] = None,
         compile_specs: Optional[List[CompileSpec]] = None,
         take_over_mutable_buffer: Optional[bool] = True,
+        lower_full_graph: bool = False,
+        tag_constant_data: bool = True,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
@@ -71,6 +141,8 @@ def __init__(
             compile_specs=compile_specs if compile_specs is not None else [],
         )
         self.take_over_mutable_buffer = take_over_mutable_buffer
+        self.lower_full_graph = lower_full_graph
+        self.tag_constant_data = tag_constant_data
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         # Run the CapabilityBasedPartitioner to return the largest possible
@@ -80,7 +152,9 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
 
         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            OperatorsSupportedForCoreMLBackend(self.skip_ops_for_coreml_delegation),
+            OperatorsSupportedForCoreMLBackend(
+                self.skip_ops_for_coreml_delegation, self.lower_full_graph
+            ),
             allows_single_node_partition=True,
         )
         partition_list = capability_partitioner.propose_partitions()
@@ -90,7 +164,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 node.meta["delegation_tag"] = tag
                 partition_tags[tag] = self.delegation_spec
 
-        tag_constant_data(exported_program)
+        if self.tag_constant_data:
+            tag_constant_data(exported_program)
         if self.take_over_mutable_buffer:
             logger.info(
                 "Core ML partitioner will take over torch mutable buffer as Core ML state, "
@@ -109,7 +184,9 @@ def ops_to_not_decompose(
         self, ep: ExportedProgram
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
         do_not_decompose = []
-        op_support = OperatorsSupportedForCoreMLBackend()
+        op_support = OperatorsSupportedForCoreMLBackend(
+            self.skip_ops_for_coreml_delegation, self.lower_full_graph
+        )
         _logged_warnings = set()
 
         # CoreML prevents certain ops (like triu) from lowering to CoreML when put in the ExecuTorch op namespace
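
The two new constructor flags are consumed like any other CoreMLPartitioner option. A minimal usage sketch follows; the model, inputs, and variable names are illustrative and not part of this commit:

    import torch
    import executorch.exir
    from executorch.backends.apple.coreml.partition.coreml_partitioner import CoreMLPartitioner

    class SmallModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x)

    # Export a toy model (illustrative model; relu is supported by Core ML).
    ep = torch.export.export(SmallModel().eval(), (torch.randn(2, 2),))

    # lower_full_graph=True makes partitioning all-or-nothing: any op Core ML cannot
    # handle raises NotImplementedError instead of being left in the host graph.
    # tag_constant_data=False would skip tag_constant_data(), so constants are passed
    # to the delegate as extra inputs rather than tagged into the delegated payload.
    partitioner = CoreMLPartitioner(lower_full_graph=True, tag_constant_data=True)
    lowered = executorch.exir.to_edge_transform_and_lower(ep, partitioner=[partitioner])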

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 108 additions & 0 deletions
@@ -2,6 +2,7 @@
 #
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
+import copy
 import unittest
 
 import coremltools as ct
@@ -16,6 +17,24 @@
 from executorch.exir.backend.utils import format_delegated_graph
 
 
+@torch.library.custom_op("unsupported::linear", mutates_args=())
+def _(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.aten.linear.default(x, w, b)
+
+
+@torch.library.register_fake("unsupported::linear")
+def _(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.aten.linear.default(x, w, b)
+
+
 class TestCoreMLPartitioner(unittest.TestCase):
     edge_compile_config = executorch.exir.EdgeCompileConfig()
 
@@ -200,10 +219,99 @@ def forward(self, q, k_val, input_pos):
                 "getitem",
             ]
 
+    def test_lower_full_graph(self):
+        class Model(torch.nn.Module):
+            def forward(self, a, x, b):
+                out = torch.ops.aten.linear.default(a, x, b)
+                out2 = torch.ops.unsupported.linear.default(out, x, b)
+                return out2
+
+        model = Model()
+        model.eval()
+
+        example_inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
+        exir_program_aten = torch.export.export(model, example_inputs, strict=True)
+        edge_program_manager = executorch.exir.to_edge(exir_program_aten)
+        edge_program_manager2 = copy.deepcopy(edge_program_manager)
+
+        delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
+
+        print(delegated_program_manager.exported_program())
+
+        for node in delegated_program_manager.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "unsupported.linear.default",
+                    "executorch_call_delegate",
+                    "getitem",
+                ], node.target.__name__
+
+        with self.assertRaises(NotImplementedError):
+            edge_program_manager2.to_backend(CoreMLPartitioner(lower_full_graph=True))
+
+    def test_symint_arg(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, w, b, y):
+                val = y.item()
+                out = torch.ops.unsupported.linear.default(x, w, b + val) + val
+                out2 = torch.ops.aten.linear.default(out, w, b) + val
+                return out2
+
+        model = Model()
+        model.eval()
+        example_inputs = (
+            torch.randn(2, 2),
+            torch.randn(2, 2),
+            torch.randn(2, 2),
+            torch.tensor(2),
+        )
+        exir_program_aten = torch.export.export(model, example_inputs)
+
+        edge_program_manager = executorch.exir.to_edge(exir_program_aten)
+
+        delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
+
+        # This op has symbolic args
+        assert (
+            "torch.ops.aten.scalar_tensor.default"
+            in delegated_program_manager.exported_program().graph_module.code
+        )
+
+    def test_tag_constant_data_false(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = Model()
+        model.eval()
+        example_inputs = (torch.randn(2, 2),)
+        exir_program_aten = torch.export.export(model, example_inputs)
+
+        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
+            exir_program_aten,
+            partitioner=[CoreMLPartitioner(tag_constant_data=False)],
+        )
+        for node in edge_program_manager.exported_program().graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target.__name__ == "executorch_call_delegate"
+            ):
+                break
+
+        # lowered_module_0, x, p_linear_weight, p_linear_bias
+        assert len(node.args) == 4
+
 
 if __name__ == "__main__":
     test_runner = TestCoreMLPartitioner()
     test_runner.test_add_sub_skip_mm()
     test_runner.test_vit_skip_conv()
     test_runner.test_ops_to_not_decompose()
     test_runner.test_buffer()
+    test_runner.test_lower_full_graph()
+    test_runner.test_symint_arg()
+    test_runner.test_tag_constant_data_false()
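
One way to see which ops stayed outside the Core ML delegate (for example torch.ops.unsupported.linear.default in test_lower_full_graph) is the format_delegated_graph helper that this test file already imports. A short sketch, reusing the delegated_program_manager name from the test above:

    from executorch.exir.backend.utils import format_delegated_graph

    # Print the lowered graph with delegated submodules expanded, to confirm which
    # call_function nodes were wrapped in executorch_call_delegate and which
    # remained in the host graph.
    print(format_delegated_graph(delegated_program_manager.exported_program().graph_module))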
