Skip to content

Commit 10a4fb9

Browse files
committed
up
1 parent 66c928a commit 10a4fb9

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,15 @@ def log_once(self, msg: str) -> None:
5959

6060
def should_skip_op_for_delegation(self, node_target_name: str) -> bool:
6161
skipped_ops = self.skip_ops_for_coreml_delegation or []
62-
6362
if node_target_name in skipped_ops:
63+
assert (
64+
not self.lower_full_graph
65+
), f"Cannot skip {node_target_name} because lower_full_graph is True. Please set skip_ops_for_coreml_delegation=None or lower_full_graph=False in the CoreMLPartitioner"
66+
self.log_once(
67+
"Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
68+
+ node_target_name
69+
)
6470
return True
65-
66-
# For backwards compatibility
67-
split_name = node_target_name.split("::")
68-
if len(split_name) == 2:
69-
namespace, name_without_namespace = split_name
70-
if namespace == "aten" and name_without_namespace in skipped_ops:
71-
return True
72-
7371
return False
7472

7573
def should_override_support(self, node) -> bool:
@@ -83,9 +81,10 @@ def should_override_support(self, node) -> bool:
8381
exir_ops.edge.aten.add.Tensor,
8482
]
8583
and "alpha" in node.kwargs
84+
and node.kwargs["alpha"] != 1
8685
):
8786
self.log_once(
88-
"torch.ops.aten.{sub, add}.Tensor with alpha is not supported by CoreML. Overriding support."
87+
"torch.ops.aten.{sub, add}.Tensor with alpha != 1 is not supported by CoreML. Overriding support."
8988
)
9089
return True
9190

@@ -139,16 +138,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
139138
# check if the PyTorch op get called is supported in Core ML
140139
elif node.op == "call_function":
141140
# skip ops if specified by user
142-
node_target_name = node.target.name().lower()
141+
node_target_name = getattr(node.target, "__name__", "").lower()
143142

144143
if self.should_skip_op_for_delegation(node_target_name):
145-
self.log_once(
146-
"Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
147-
+ node_target_name
148-
)
149-
assert (
150-
not self.lower_full_graph
151-
), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
152144
return False
153145

154146
# query coremltools to see if node is supported

0 commit comments

Comments
 (0)