Skip to content

Commit f54dc1a

Browse files
authored
Add unit tests for deprecation warning
Adds two unit tests to verify the deprecation warning behavior: 1. `test_deprecation_warning_for_to_backend_workflow`: Verifies that the warning IS logged when using the deprecated to_edge() + to_backend() workflow 2. `test_no_warning_for_to_edge_transform_and_lower_workflow`: Verifies that the warning is NOT logged when using the recommended to_edge_transform_and_lower() workflow Also adds necessary imports: - import io - import logging - from executorch.exir import to_edge, to_edge_transform_and_lower These tests ensure the deprecation warning only appears in the deprecated flow, following the same pattern as XNNPACK (PR #13209).
1 parent 3a58575 commit f54dc1a

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import copy
66
import sys
77
import unittest
8+
import io
9+
import logging
810

911
import coremltools as ct
1012

@@ -16,6 +18,7 @@
1618
from executorch.backends.apple.coreml.compiler import CoreMLBackend
1719
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
1820
from executorch.exir.backend.utils import format_delegated_graph
21+
from executorch.exir import to_edge, to_edge_transform_and_lower
1922

2023

2124
@torch.library.custom_op("unsupported::linear", mutates_args=())
@@ -346,3 +349,78 @@ def forward(self, x):
346349
test_runner.test_lower_full_graph()
347350
# test_runner.test_symint_arg()
348351
test_runner.test_take_over_constant_data_false()
352+
353+
def test_deprecation_warning_for_to_backend_workflow(self):
354+
"""
355+
Test that the deprecated to_edge + to_backend workflow shows a deprecation
356+
warning.
357+
"""
358+
class SimpleModel(torch.nn.Module):
359+
def __init__(self):
360+
super().__init__()
361+
self.linear = torch.nn.Linear(10, 5)
362+
363+
def forward(self, x):
364+
return self.linear(x)
365+
366+
model = SimpleModel()
367+
x = torch.randn(1, 10)
368+
369+
exported_model = torch.export.export(model, (x,))
370+
371+
# Capture log output to check for deprecation warning
372+
log_capture_string = io.StringIO()
373+
ch = logging.StreamHandler(log_capture_string)
374+
ch.setLevel(logging.WARNING)
375+
376+
logger = logging.getLogger(
377+
"executorch.backends.apple.coreml.partition.coreml_partitioner"
378+
)
379+
logger.addHandler(ch)
380+
logger.setLevel(logging.WARNING)
381+
382+
edge = to_edge(exported_model)
383+
partitioner = CoreMLPartitioner()
384+
385+
edge.to_backend(partitioner)
386+
387+
log_contents = log_capture_string.getvalue()
388+
self.assertIn("DEPRECATION WARNING", log_contents)
389+
self.assertIn("to_edge() + to_backend()", log_contents)
390+
self.assertIn("to_edge_transform_and_lower()", log_contents)
391+
392+
def test_no_warning_for_to_edge_transform_and_lower_workflow(self):
393+
"""
394+
Test that the recommended to_edge_transform_and_lower workflow does NOT
395+
show a deprecation warning.
396+
"""
397+
class SimpleModel(torch.nn.Module):
398+
def __init__(self):
399+
super().__init__()
400+
self.linear = torch.nn.Linear(10, 5)
401+
402+
def forward(self, x):
403+
return self.linear(x)
404+
405+
model = SimpleModel()
406+
x = torch.randn(1, 10)
407+
408+
exported_model = torch.export.export(model, (x,))
409+
410+
# Capture log output to check for deprecation warning
411+
log_capture_string = io.StringIO()
412+
ch = logging.StreamHandler(log_capture_string)
413+
ch.setLevel(logging.WARNING)
414+
415+
logger = logging.getLogger(
416+
"executorch.backends.apple.coreml.partition.coreml_partitioner"
417+
)
418+
logger.addHandler(ch)
419+
logger.setLevel(logging.WARNING)
420+
421+
partitioner = CoreMLPartitioner()
422+
423+
to_edge_transform_and_lower(exported_model, partitioner=[partitioner])
424+
425+
log_contents = log_capture_string.getvalue()
426+
self.assertNotIn("DEPRECATION WARNING", log_contents)

0 commit comments

Comments
 (0)