Skip to content

Commit 09d5255

Browse files
authored
Add deprecation warning for to_edge + to_backend workflow in XnnpackPartitioner (#13209)
### Summary This PR adds a deprecation warning in the XnnpackPartitioner to guide users away from the deprecated to_edge() + to_backend() workflow and toward the recommended to_edge_transform_and_lower() flow. We inspect the call stack in the partitioner to detect when the partitioner is called from the deprecated workflow and then print out a warning statement. This helps prevent issues that can arise from the deprecated workflow. ### Test plan Added tests testing the deprecation warning functionality in test_xnnpack_partitioner.py. Tests verify that the warning appears when using to_edge() + to_backend() and does not appear when using to_edge_transform_and_lower().
1 parent 4ce7078 commit 09d5255

File tree

2 files changed

+116
-1
lines changed

2 files changed

+116
-1
lines changed

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import inspect
78
import itertools
8-
99
import logging
1010
from typing import List, Optional, Type, Union
1111

@@ -65,6 +65,37 @@ def __init__(
6565
self.per_op_mode = per_op_mode
6666
super().__init__(delegation_spec, initialized_configs)
6767

68+
def _check_if_called_from_to_backend(self) -> bool:
69+
"""
70+
Check if the partition method is being called from the deprecated to_backend workflow.
71+
Returns True if called from deprecated direct to_backend, False if called from to_edge_transform_and_lower.
72+
"""
73+
stack = inspect.stack()
74+
75+
for frame_info in stack:
76+
if frame_info.function == "to_edge_transform_and_lower":
77+
return False
78+
79+
for frame_info in stack:
80+
if frame_info.function == "to_backend":
81+
filename = frame_info.filename
82+
if "program/_program.py" in filename:
83+
return True
84+
return False
85+
86+
def partition(self, exported_program):
87+
"""
88+
Override partition to add deprecation warning when called from to_backend.
89+
"""
90+
# Check if we're being called from the deprecated to_backend workflow
91+
if self._check_if_called_from_to_backend():
92+
logger.warning(
93+
"\nDEPRECATION WARNING: You are using the deprecated 'to_edge() + to_backend()' workflow. "
94+
"Please consider migrating to 'to_edge_transform_and_lower()' for better error handling and optimization. "
95+
)
96+
97+
return super().partition(exported_program)
98+
6899
def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
69100
"""
70101
generate_partitions is different if partitioner is set to per_op_mode
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import io
8+
import logging
9+
import unittest
10+
11+
import torch
12+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
13+
from executorch.exir import to_edge, to_edge_transform_and_lower
14+
from torch.export import export
15+
16+
17+
class TestXnnpackPartitioner(unittest.TestCase):
18+
"""Test cases for XnnpackPartitioner functionality and deprecation warnings."""
19+
20+
class SimpleModel(torch.nn.Module):
21+
def __init__(self):
22+
super().__init__()
23+
self.linear = torch.nn.Linear(10, 5)
24+
25+
def forward(self, x):
26+
return self.linear(x)
27+
28+
def test_deprecation_warning_for_to_backend_workflow(self):
29+
"""
30+
Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.
31+
"""
32+
model = self.SimpleModel()
33+
x = torch.randn(1, 10)
34+
35+
exported_model = export(model, (x,))
36+
37+
# Capture log output to check for deprecation warning
38+
log_capture_string = io.StringIO()
39+
ch = logging.StreamHandler(log_capture_string)
40+
ch.setLevel(logging.WARNING)
41+
42+
logger = logging.getLogger(
43+
"executorch.backends.xnnpack.partition.xnnpack_partitioner"
44+
)
45+
logger.addHandler(ch)
46+
logger.setLevel(logging.WARNING)
47+
48+
edge = to_edge(exported_model)
49+
partitioner = XnnpackPartitioner()
50+
51+
edge.to_backend(partitioner)
52+
53+
log_contents = log_capture_string.getvalue()
54+
self.assertIn("DEPRECATION WARNING", log_contents)
55+
self.assertIn("to_edge() + to_backend()", log_contents)
56+
self.assertIn("to_edge_transform_and_lower()", log_contents)
57+
58+
def test_no_warning_for_to_edge_transform_and_lower_workflow(self):
59+
"""
60+
Test that the recommended to_edge_transform_and_lower workflow does NOT show a deprecation warning.
61+
"""
62+
63+
model = self.SimpleModel()
64+
x = torch.randn(1, 10)
65+
66+
exported_model = export(model, (x,))
67+
68+
# Capture log output to check for deprecation warning
69+
log_capture_string = io.StringIO()
70+
ch = logging.StreamHandler(log_capture_string)
71+
ch.setLevel(logging.WARNING)
72+
73+
logger = logging.getLogger(
74+
"executorch.backends.xnnpack.partition.xnnpack_partitioner"
75+
)
76+
logger.addHandler(ch)
77+
logger.setLevel(logging.WARNING)
78+
79+
partitioner = XnnpackPartitioner()
80+
81+
to_edge_transform_and_lower(exported_model, partitioner=[partitioner])
82+
83+
log_contents = log_capture_string.getvalue()
84+
self.assertNotIn("DEPRECATION WARNING", log_contents)

0 commit comments

Comments
 (0)