Skip to content

Commit 3a58575

Browse files
authored
Add call stack inspection to show warning only for deprecated to_edge workflow
Following the pattern from PR #13209 (XNNPACK), this commit: 1. Adds `import inspect` to enable call stack inspection 2. Adds `_check_if_called_from_to_backend()` helper method that: - Returns False if called from to_edge_transform_and_lower (recommended flow) - Returns True if called from to_backend in deprecated flow 3. Wraps the deprecation warning in a conditional check This ensures the warning only appears when using the deprecated to_edge() + to_backend() workflow, not when using the recommended to_edge_transform_and_lower() flow. Unit tests will be added in a separate commit.
1 parent d99c31b commit 3a58575

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Please refer to the license found in the LICENSE file in the root directory of the source tree.
44

55
import logging
6+
import inspect
67
from typing import Callable, List, Optional, Tuple
78

89
import coremltools as ct
@@ -222,15 +223,38 @@ def __init__(
222223
self.take_over_mutable_buffer
223224
), "When lower_full_graph=True, you must set take_over_mutable_buffer=True"
224225

226+
227+
def _check_if_called_from_to_backend(self) -> bool:
228+
"""
229+
Check if the partition method is being called from the deprecated
230+
to_backend workflow.
231+
232+
Returns True if called from deprecated direct to_backend, False if called
233+
from to_edge_transform_and_lower.
234+
"""
235+
stack = inspect.stack()
236+
237+
for frame_info in stack:
238+
if frame_info.function == "to_edge_transform_and_lower":
239+
return False
240+
241+
for frame_info in stack:
242+
if frame_info.function == "to_backend":
243+
filename = frame_info.filename
244+
if "program/_program.py" in filename:
245+
return True
246+
return False
225247
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
226248
# Run the CapabilityBasedPartitioner to return the largest possible
227249
# subgraphs containing the nodes with the tags
228250
logger.info("CoreMLPartitioner::partition")
229-
logger.warning("Using the old `to_edge()` flow with CoreML may result in performance regression. "
230-
"The recommended flow is to use `to_edge_transform_and_lower()` with the CoreML partitioner. "
231-
"See the documentation for more details: "
232-
"https://github.com/pytorch/executorch/blob/main/docs/source/backends/coreml/coreml-overview.md#using-the-core-ml-backend"
233-
)
251+
# Check if we're being called from the deprecated to_backend workflow
252+
if self._check_if_called_from_to_backend():
253+
logger.warning("Using the old `to_edge()` flow with CoreML may result in performance regression. "
254+
"The recommended flow is to use `to_edge_transform_and_lower()` with the CoreML partitioner. "
255+
"See the documentation for more details: "
256+
"https://github.com/pytorch/executorch/blob/main/docs/source/backends/coreml/coreml-overview.md#using-the-core-ml-backend"
257+
)
234258
partition_tags = {}
235259

236260
capability_partitioner = CapabilityBasedPartitioner(

0 commit comments

Comments
 (0)