
Commit 17e6dcf

Commit message: up
Parent: 7248122

2 files changed: 94 additions & 72 deletions

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 29 additions & 36 deletions
@@ -41,7 +41,12 @@ def __init__(
         assert (
             len(self.skip_ops_for_coreml_delegation or []) == 0
         ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
-        self._logged_skips = set()
+        self._logged_msgs = set()
+
+    def log_once(self, msg: str) -> None:
+        if msg not in self._logged_msgs:
+            logging.info(msg)
+            self._logged_msgs.add(msg)
 
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         # get_attr node can always be supported on any backend
@@ -52,37 +57,32 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             # skip ops if specified by user
             node_target_name = getattr(node.target, "__name__", "").lower()
             if node_target_name in (self.skip_ops_for_coreml_delegation or []):
-                skip_str = (
+                self.log_once(
                     "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
                     + node_target_name
                 )
-                if skip_str not in self._logged_skips:
-                    logging.info(skip_str)
-                    self._logged_skips.add(skip_str)
                 assert (
                     not self.lower_full_graph
                 ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
                 return False
 
-            # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
-            # in the placeholders due to partitioning, which CoreML does not support
-            if not self.lower_full_graph and any(
-                isinstance(arg, torch.fx.Node)
-                and isinstance(
-                    arg.meta.get("val", None),
-                    (torch.SymInt, torch.SymBool, torch.SymFloat),
-                )
-                for arg in node.args
-            ):
-                skip_str = (
-                    "Skipping op for CoreML delegation because it contains symbolic args: "
-                    + node_target_name
-                )
-                if skip_str not in self._logged_skips:
-                    logging.info(skip_str)
-                    self._logged_skips.add(skip_str)
-                assert not self.lower_full_graph
-                return False
+            # TODO: enable this after bugs in to_edge_transform_and_lower are fixed
+            # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
+            # # in the placeholders due to partitioning, which CoreML does not support
+            # if not self.lower_full_graph and any(
+            #     isinstance(arg, torch.fx.Node)
+            #     and isinstance(
+            #         arg.meta.get("val", None),
+            #         (torch.SymInt, torch.SymBool, torch.SymFloat),
+            #     )
+            #     for arg in node.args
+            # ):
+            #     self.log_once(
+            #         "Skipping op for CoreML delegation because it contains symbolic args: "
+            #         + node_target_name
+            #     )
+            #     assert not self.lower_full_graph
+            #     return False
 
             # query coremltools to see if node is supported
             is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
@@ -100,26 +100,20 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
                     Do not file an issue with ExecuTorch for op support.
                     """
                 )
-                skip_str = (
+                self.log_once(
                     "Skipping op for CoreML delegation because it is not supported by CoreML: "
                     + node_target_name
                 )
-                if skip_str not in self._logged_skips:
-                    logging.info(skip_str)
-                    self._logged_skips.add(skip_str)
             return is_supported
         # cowardly refuse to support all other types of node:
         # 1. placeholder / output nodes should not be tagged
         # reference: https://github.com/pytorch/executorch/pull/1398
         # 2. call_module / call_method should have been replaced with call_function?
         else:
-            skip_str = (
+            self.log_once(
                 "Skipping op for CoreML delegation because it is not get_attr or call_function: "
                 + node.op
            )
-            if skip_str not in self._logged_skips:
-                logging.info(skip_str)
-                self._logged_skips.add(skip_str)
             return False
 
 
@@ -131,7 +125,7 @@ def __init__(
         compile_specs: Optional[List[CompileSpec]] = None,
         take_over_mutable_buffer: Optional[bool] = True,
         lower_full_graph: bool = False,
-        tag_constant_data: bool = True,
+        take_over_constant_data: bool = True,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
@@ -142,7 +136,7 @@ def __init__(
         )
         self.take_over_mutable_buffer = take_over_mutable_buffer
         self.lower_full_graph = lower_full_graph
-        self.tag_constant_data = tag_constant_data
+        self.take_over_constant_data = take_over_constant_data
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         # Run the CapabilityBasedPartitioner to return the largest possible
@@ -164,7 +158,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 node.meta["delegation_tag"] = tag
                 partition_tags[tag] = self.delegation_spec
 
-        if self.tag_constant_data:
+        if self.take_over_constant_data:
             tag_constant_data(exported_program)
         if self.take_over_mutable_buffer:
             logger.info(
@@ -215,5 +209,4 @@ def ops_to_not_decompose(
             if warn_str not in _logged_warnings:
                 logger.warning(warn_str)
                 _logged_warnings.add(warn_str)
-
         return do_not_decompose, None
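
The main refactor in this file collapses three copies of the "build skip_str, check the set, log, add to the set" sequence into the single log_once helper. Below is a minimal standalone sketch of that deduplicated-logging pattern; the _DedupDemo class name and the example messages are illustrative only and are not part of the commit.

import logging

logging.basicConfig(level=logging.INFO)


class _DedupDemo:
    """Illustrative stand-in for the partitioner's logging state."""

    def __init__(self) -> None:
        self._logged_msgs = set()

    def log_once(self, msg: str) -> None:
        # Emit each distinct message at most once, then remember it.
        if msg not in self._logged_msgs:
            logging.info(msg)
            self._logged_msgs.add(msg)


demo = _DedupDemo()
for op in ["aten.linear.default", "aten.linear.default", "aten.relu.default"]:
    # The repeated "aten.linear.default" entry is only logged the first time.
    demo.log_once("Skipping op for CoreML delegation: " + op)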

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 65 additions & 36 deletions
@@ -3,6 +3,7 @@
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
 import copy
+import sys
 import unittest
 
 import coremltools as ct
@@ -15,6 +16,7 @@
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
 from executorch.exir.backend.utils import format_delegated_graph
+from executorch.runtime import Runtime
 
 
 @torch.library.custom_op("unsupported::linear", mutates_args=())
@@ -35,6 +37,10 @@ def _(
     return torch.ops.aten.linear.default(x, w, b)
 
 
+_TEST_RUNTIME = sys.platform == "darwin"
+_TEST_RUNTIME = False  # Disable until segfault fixed: https://github.com/pytorch/executorch/issues/12408
+
+
 class TestCoreMLPartitioner(unittest.TestCase):
     edge_compile_config = executorch.exir.EdgeCompileConfig()
 
@@ -236,8 +242,6 @@ def forward(self, a, x, b):
 
         delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
 
-        print(delegated_program_manager.exported_program())
-
         for node in delegated_program_manager.exported_program().graph.nodes:
             if node.op == "call_function":
                 assert node.target.__name__ in [
@@ -249,51 +253,62 @@ def forward(self, a, x, b):
         with self.assertRaises(NotImplementedError):
             edge_program_manager2.to_backend(CoreMLPartitioner(lower_full_graph=True))
 
-    def test_symint_arg(self):
-        class Model(torch.nn.Module):
-            def forward(self, x, w, b, y):
-                val = y.item()
-                out = torch.ops.unsupported.linear.default(x, w, b + val) + val
-                out2 = torch.ops.aten.linear.default(out, w, b) + val
-                return out2
-
-        model = Model()
-        model.eval()
-        example_inputs = (
-            torch.randn(2, 2),
-            torch.randn(2, 2),
-            torch.randn(2, 2),
-            torch.tensor(2),
-        )
-        exir_program_aten = torch.export.export(model, example_inputs)
-
-        edge_program_manager = executorch.exir.to_edge(exir_program_aten)
-
-        delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
-
-        # This op has symbolic args
-        assert (
-            "torch.ops.aten.scalar_tensor.default"
-            in delegated_program_manager.exported_program().graph_module.code
-        )
-
-    def test_tag_constant_data_false(self):
+    # def test_symint_arg(self):
+    #     class Model(torch.nn.Module):
+    #         def forward(self, x, w, b, y):
+    #             val = y.item()
+    #             torch._check(val >= 0)
+    #             torch._check(val < 2)
+    #             out = torch.ops.aten.linear.default(x, w, b)
+    #             out2 = out.relu()[val]
+    #             return out2
+
+    #     model = Model()
+    #     model.eval()
+    #     example_inputs = (
+    #         torch.randn(2, 2),
+    #         torch.randn(2, 2),
+    #         torch.randn(2, 2),
+    #         torch.tensor(2),
+    #     )
+    #     exir_program_aten = torch.export.export(model, example_inputs)
+
+    #     edge_program_manager = executorch.exir.to_edge(exir_program_aten)
+
+    #     delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner(skip_ops_for_coreml_delegation=["aten.scalar_tensor.default"]))
+
+    #     # This op has symbolic args
+    #     assert (
+    #         "torch.ops.aten._assert_scalar.default"
+    #         in delegated_program_manager.exported_program().graph_module.code
+    #     )
+
+    #     if _TEST_RUNTIME:
+    #         et_prog = delegated_program_manager.to_executorch()
+    #         runtime = Runtime.get()
+    #         program = runtime.load_program(et_prog.buffer)
+    #         method = program.load_method("forward")
+    #         et_outputs = method.execute(*example_inputs)[0]
+    #         eager_outputs = model(*example_inputs)
+    #         self.assertTrue(torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02))
+
+    def test_take_over_constant_data_false(self):
         class Model(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.linear = torch.nn.Linear(2, 2)
+                self.linear = torch.nn.Linear(50, 100)
 
             def forward(self, x):
                 return self.linear(x)
 
         model = Model()
         model.eval()
-        example_inputs = (torch.randn(2, 2),)
+        example_inputs = (torch.randn(2, 50),)
         exir_program_aten = torch.export.export(model, example_inputs)
 
         edge_program_manager = executorch.exir.to_edge_transform_and_lower(
             exir_program_aten,
-            partitioner=[CoreMLPartitioner(tag_constant_data=False)],
+            partitioner=[CoreMLPartitioner(take_over_constant_data=False)],
         )
         for node in edge_program_manager.exported_program().graph.nodes:
             if (
@@ -305,6 +320,20 @@ def forward(self, x):
                 # lowered_module_0, x, p_linear_weight, p_linear_bias
                 assert len(node.args) == 4
 
+        if _TEST_RUNTIME:
+            et_prog = edge_program_manager.to_executorch()
+            runtime = Runtime.get()
+            program = runtime.load_program(et_prog.buffer)
+            method = program.load_method("forward")
+            et_outputs = method.execute(*example_inputs)[0]
+            eager_outputs = model(*example_inputs)
+            self.assertTrue(
+                torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
+            )
+
+            with open("/tmp/et_model.pte", "wb") as file:
+                et_prog.write_to_file(file)
+
 
 if __name__ == "__main__":
     test_runner = TestCoreMLPartitioner()
@@ -313,5 +342,5 @@ def forward(self, x):
     test_runner.test_ops_to_not_decompose()
     test_runner.test_buffer()
     test_runner.test_lower_full_graph()
-    test_runner.test_symint_arg()
-    test_runner.test_tag_constant_data_false()
+    # test_runner.test_symint_arg()
+    test_runner.test_take_over_constant_data_false()
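
For callers, the user-visible change in this commit is the rename of the CoreMLPartitioner keyword tag_constant_data to take_over_constant_data. A minimal migration sketch is below, assuming the same export flow the test above exercises (torch.export.export followed by executorch.exir.to_edge_transform_and_lower); the stand-in linear model is illustrative only.

import torch
import executorch.exir
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# Stand-in model, mirroring the shapes used in the test above.
model = torch.nn.Linear(50, 100).eval()
example_inputs = (torch.randn(2, 50),)
exported = torch.export.export(model, example_inputs)

# Old keyword (before this commit): CoreMLPartitioner(tag_constant_data=False)
# New keyword (after this commit):
edge_program_manager = executorch.exir.to_edge_transform_and_lower(
    exported,
    partitioner=[CoreMLPartitioner(take_over_constant_data=False)],
)

With the flag set to False, the partitioner does not call tag_constant_data on the exported program, so the weights stay outside the delegate payload and are passed to the lowered module as arguments, which is what the test's len(node.args) == 4 check verifies.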
