Skip to content

Commit 959e793

Browse files
committed
Add warning when using deprecated to_edge and to_backend methods
1 parent 5797608 commit 959e793

File tree

2 files changed

+114
-1
lines changed

2 files changed

+114
-1
lines changed

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import inspect
78
import itertools
8-
99
import logging
1010
from typing import List, Optional, Type, Union
1111

@@ -65,6 +65,37 @@ def __init__(
6565
self.per_op_mode = per_op_mode
6666
super().__init__(delegation_spec, initialized_configs)
6767

68+
def _check_if_called_from_to_backend(self) -> bool:
69+
"""
70+
Check if the partition method is being called from the deprecated to_backend workflow.
71+
Returns True if called from deprecated direct to_backend, False if called from to_edge_transform_and_lower.
72+
"""
73+
stack = inspect.stack()
74+
75+
for frame_info in stack:
76+
if frame_info.function == "to_edge_transform_and_lower":
77+
return False
78+
79+
for frame_info in stack:
80+
if frame_info.function == "to_backend":
81+
filename = frame_info.filename
82+
if "program/_program.py" in filename:
83+
return True
84+
return False
85+
86+
def partition(self, exported_program):
87+
"""
88+
Override partition to add deprecation warning when called from to_backend.
89+
"""
90+
# Check if we're being called from the deprecated to_backend workflow
91+
if self._check_if_called_from_to_backend():
92+
logger.warning(
93+
"\nDEPRECATION WARNING: You are using the deprecated 'to_edge() + to_backend()' workflow. "
94+
"Please consider migrating to 'to_edge_transform_and_lower()' for better error handling and optimization. "
95+
)
96+
97+
return super().partition(exported_program)
98+
6899
def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
69100
"""
70101
generate_partitions is different if partitioner is set to per_op_mode
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
import logging
9+
import io
10+
11+
import torch
12+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
13+
from executorch.exir import to_edge, to_edge_transform_and_lower
14+
from torch.export import export
15+
16+
17+
class TestXnnpackPartitioner(unittest.TestCase):
18+
"""Test cases for XnnpackPartitioner functionality and deprecation warnings."""
19+
class SimpleModel(torch.nn.Module):
20+
def __init__(self):
21+
super().__init__()
22+
self.linear = torch.nn.Linear(10, 5)
23+
24+
def forward(self, x):
25+
return self.linear(x)
26+
27+
def test_deprecation_warning_for_to_backend_workflow(self):
28+
"""
29+
Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.
30+
"""
31+
model = self.SimpleModel()
32+
x = torch.randn(1, 10)
33+
34+
exported_model = export(model, (x,))
35+
36+
# Capture log output to check for deprecation warning
37+
log_capture_string = io.StringIO()
38+
ch = logging.StreamHandler(log_capture_string)
39+
ch.setLevel(logging.WARNING)
40+
41+
logger = logging.getLogger('executorch.backends.xnnpack.partition.xnnpack_partitioner')
42+
logger.addHandler(ch)
43+
logger.setLevel(logging.WARNING)
44+
45+
edge = to_edge(exported_model)
46+
partitioner = XnnpackPartitioner()
47+
48+
edge.to_backend(partitioner)
49+
50+
log_contents = log_capture_string.getvalue()
51+
self.assertIn("DEPRECATION WARNING", log_contents)
52+
self.assertIn("to_edge() + to_backend()", log_contents)
53+
self.assertIn("to_edge_transform_and_lower()", log_contents)
54+
55+
def test_no_warning_for_to_edge_transform_and_lower_workflow(self):
56+
"""
57+
Test that the recommended to_edge_transform_and_lower workflow does NOT show a deprecation warning.
58+
"""
59+
60+
model = self.SimpleModel()
61+
x = torch.randn(1, 10)
62+
63+
exported_model = export(model, (x,))
64+
65+
# Capture log output to check for deprecation warning
66+
log_capture_string = io.StringIO()
67+
ch = logging.StreamHandler(log_capture_string)
68+
ch.setLevel(logging.WARNING)
69+
70+
logger = logging.getLogger('executorch.backends.xnnpack.partition.xnnpack_partitioner')
71+
logger.addHandler(ch)
72+
logger.setLevel(logging.WARNING)
73+
74+
partitioner = XnnpackPartitioner()
75+
76+
to_edge_transform_and_lower(
77+
exported_model,
78+
partitioner=[partitioner]
79+
)
80+
81+
log_contents = log_capture_string.getvalue()
82+
self.assertNotIn("DEPRECATION WARNING", log_contents)

0 commit comments

Comments
 (0)