| 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

 | 7 | +import unittest  | 
 | 8 | +from typing import Optional, Tuple  | 
 | 9 | + | 
 | 10 | +import torch  | 
 | 11 | +from executorch.backends.xnnpack.partition.config.xnnpack_config import (  | 
 | 12 | +    ConfigPrecisionType,  | 
 | 13 | +)  | 
 | 14 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner  | 
 | 15 | +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (  | 
 | 16 | +    get_symmetric_quantization_config,  | 
 | 17 | +)  | 
 | 18 | +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (  | 
 | 19 | +    QuantizationConfig,  | 
 | 20 | +)  | 
 | 21 | +from executorch.backends.xnnpack.test.tester import Quantize, Tester  | 
 | 22 | +from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower  | 
 | 23 | + | 
 | 24 | +class TestChannelsLastTaggedReshapePass(unittest.TestCase):  | 
 | 25 | +    def setUp(self):  | 
 | 26 | +        torch._dynamo.reset()  | 
 | 27 | + | 
 | 28 | +    def run_tester(self, module, inputs):  | 
 | 29 | +        tester = Tester(  | 
 | 30 | +            module.eval(),  | 
 | 31 | +            inputs,  | 
 | 32 | +        )  | 
 | 33 | +        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()  | 
 | 34 | + | 
 | 35 | +    class ChannelLastBeforeLinear(torch.nn.Module):  | 
 | 36 | +        def __init__(self):  | 
 | 37 | +            super().__init__()  | 
 | 38 | +            self.linear = torch.nn.Linear(3, 3)  | 
 | 39 | + | 
 | 40 | +        def forward(self, x):  | 
 | 41 | +            y = x.to(memory_format=torch.channels_last)  | 
 | 42 | +            return self.linear(y)  | 
 | 43 | + | 
 | 44 | +    ChannelLastBeforeLinearModule = ChannelLastBeforeLinear()  | 
 | 45 | +    def test_channel_last_before_linear(self):  | 
 | 46 | +        self.run_tester(self.ChannelLastBeforeLinearModule, (torch.randn(1, 3, 3, 3),))  | 
 | 47 | + | 
 | 48 | + | 
 | 49 | +    class ContiguousBeforeConv(torch.nn.Module):  | 
 | 50 | +        def __init__(self):  | 
 | 51 | +            super().__init__()  | 
 | 52 | +            self.conv = torch.nn.Conv2d(3, 3, 3)  | 
 | 53 | + | 
 | 54 | +        def forward(self, x):  | 
 | 55 | +            y = x.to(memory_format=torch.contiguous_format)  | 
 | 56 | +            return self.conv(y)  | 
 | 57 | + | 
 | 58 | +    ContiguousBeforeConvModule = ContiguousBeforeConv()  | 
 | 59 | +    def test_contiguous_before_conv(self):  | 
 | 60 | +        self.run_tester(self.ContiguousBeforeConvModule, (torch.randn(1, 3, 6, 6),))  | 
 | 61 | + | 
 | 62 | +    class DtypeAndMemoryFormatConversion(torch.nn.Module):  | 
 | 63 | +        def __init__(self):  | 
 | 64 | +            super().__init__()  | 
 | 65 | +            self.conv = torch.nn.Conv2d(3, 3, 3)  | 
 | 66 | + | 
 | 67 | +        def forward(self, x):  | 
 | 68 | +            y = x.to(torch.float, memory_format=torch.channels_last)  | 
 | 69 | +            return self.conv(y)  | 
 | 70 | + | 
 | 71 | +    DtypeAndMemoryFormatConversionModule = DtypeAndMemoryFormatConversion()  | 
 | 72 | +    def test_dtype_and_memory_format_conversion(self):  | 
 | 73 | +        self.run_tester(self.DtypeAndMemoryFormatConversionModule, (torch.randint(0, 10, (1, 3, 6, 6), dtype=torch.int32),))  | 
 | 74 | + | 
 | 75 | +    class DtypeAndMemoryFormatWithLinear(torch.nn.Module):  | 
 | 76 | +        def __init__(self):  | 
 | 77 | +            super().__init__()  | 
 | 78 | +            self.linear = torch.nn.Linear(3, 3)  | 
 | 79 | + | 
 | 80 | +        def forward(self, x):  | 
 | 81 | +            y = x.to(torch.float, memory_format=torch.channels_last)  | 
 | 82 | +            return self.linear(y)  | 
 | 83 | + | 
 | 84 | +    DtypeAndMemoryFormatWithLinearModule = DtypeAndMemoryFormatWithLinear()  | 
 | 85 | +    def test_dtype_and_memory_format_with_linear(self):  | 
 | 86 | +        self.run_tester(self.DtypeAndMemoryFormatWithLinearModule, (torch.randint(0, 10, (1, 3, 3, 3), dtype=torch.int16),))  | 
0 commit comments