| 
 | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 2 | +# All rights reserved.  | 
 | 3 | +#  | 
 | 4 | +# This source code is licensed under the BSD-style license found in the  | 
 | 5 | +# LICENSE file in the root directory of this source tree.  | 
 | 6 | + | 
 | 7 | +import unittest  | 
 | 8 | +import logging  | 
 | 9 | +import io  | 
 | 10 | + | 
 | 11 | +import torch  | 
 | 12 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner  | 
 | 13 | +from executorch.exir import to_edge, to_edge_transform_and_lower  | 
 | 14 | +from torch.export import export  | 
 | 15 | + | 
 | 16 | + | 
 | 17 | +class TestXnnpackPartitioner(unittest.TestCase):  | 
 | 18 | +    """Test cases for XnnpackPartitioner functionality and deprecation warnings."""  | 
 | 19 | +    class SimpleModel(torch.nn.Module):  | 
 | 20 | +        def __init__(self):  | 
 | 21 | +            super().__init__()  | 
 | 22 | +            self.linear = torch.nn.Linear(10, 5)  | 
 | 23 | + | 
 | 24 | +        def forward(self, x):  | 
 | 25 | +            return self.linear(x)  | 
 | 26 | + | 
 | 27 | +    def test_deprecation_warning_for_to_backend_workflow(self):  | 
 | 28 | +        """  | 
 | 29 | +        Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.  | 
 | 30 | +        """  | 
 | 31 | +        model = self.SimpleModel()  | 
 | 32 | +        x = torch.randn(1, 10)  | 
 | 33 | + | 
 | 34 | +        exported_model = export(model, (x,))  | 
 | 35 | + | 
 | 36 | +        # Capture log output to check for deprecation warning  | 
 | 37 | +        log_capture_string = io.StringIO()  | 
 | 38 | +        ch = logging.StreamHandler(log_capture_string)  | 
 | 39 | +        ch.setLevel(logging.WARNING)  | 
 | 40 | + | 
 | 41 | +        logger = logging.getLogger('executorch.backends.xnnpack.partition.xnnpack_partitioner')  | 
 | 42 | +        logger.addHandler(ch)  | 
 | 43 | +        logger.setLevel(logging.WARNING)  | 
 | 44 | + | 
 | 45 | +        edge = to_edge(exported_model)  | 
 | 46 | +        partitioner = XnnpackPartitioner()  | 
 | 47 | + | 
 | 48 | +        edge.to_backend(partitioner)  | 
 | 49 | + | 
 | 50 | +        log_contents = log_capture_string.getvalue()  | 
 | 51 | +        self.assertIn("DEPRECATION WARNING", log_contents)  | 
 | 52 | +        self.assertIn("to_edge() + to_backend()", log_contents)  | 
 | 53 | +        self.assertIn("to_edge_transform_and_lower()", log_contents)  | 
 | 54 | + | 
 | 55 | +    def test_no_warning_for_to_edge_transform_and_lower_workflow(self):  | 
 | 56 | +        """  | 
 | 57 | +        Test that the recommended to_edge_transform_and_lower workflow does NOT show a deprecation warning.  | 
 | 58 | +        """  | 
 | 59 | + | 
 | 60 | +        model = self.SimpleModel()  | 
 | 61 | +        x = torch.randn(1, 10)  | 
 | 62 | + | 
 | 63 | +        exported_model = export(model, (x,))  | 
 | 64 | + | 
 | 65 | +        # Capture log output to check for deprecation warning  | 
 | 66 | +        log_capture_string = io.StringIO()  | 
 | 67 | +        ch = logging.StreamHandler(log_capture_string)  | 
 | 68 | +        ch.setLevel(logging.WARNING)  | 
 | 69 | + | 
 | 70 | +        logger = logging.getLogger('executorch.backends.xnnpack.partition.xnnpack_partitioner')  | 
 | 71 | +        logger.addHandler(ch)  | 
 | 72 | +        logger.setLevel(logging.WARNING)  | 
 | 73 | + | 
 | 74 | +        partitioner = XnnpackPartitioner()  | 
 | 75 | + | 
 | 76 | +        to_edge_transform_and_lower(  | 
 | 77 | +            exported_model,  | 
 | 78 | +            partitioner=[partitioner]  | 
 | 79 | +        )  | 
 | 80 | + | 
 | 81 | +        log_contents = log_capture_string.getvalue()  | 
 | 82 | +        self.assertNotIn("DEPRECATION WARNING", log_contents)  | 
0 commit comments