|  | 
|  | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. | 
|  | 2 | +# All rights reserved. | 
|  | 3 | +# | 
|  | 4 | +# This source code is licensed under the BSD-style license found in the | 
|  | 5 | +# LICENSE file in the root directory of this source tree. | 
|  | 6 | + | 
|  | 7 | +import unittest | 
|  | 8 | + | 
|  | 9 | +import torch | 
|  | 10 | +from executorch.backends.xnnpack.test.tester import Tester | 
|  | 11 | + | 
|  | 12 | + | 
|  | 13 | +class TestCos(unittest.TestCase): | 
|  | 14 | +    def setUp(self): | 
|  | 15 | +        torch._dynamo.reset() | 
|  | 16 | + | 
|  | 17 | +    class Cos(torch.nn.Module): | 
|  | 18 | +        def __init__(self): | 
|  | 19 | +            super().__init__() | 
|  | 20 | + | 
|  | 21 | +        def forward(self, x): | 
|  | 22 | +            z = torch.cos(x) | 
|  | 23 | +            return z | 
|  | 24 | + | 
|  | 25 | +    def _test_cos(self, inputs, legacy_mode: bool = False): | 
|  | 26 | +        tester = ( | 
|  | 27 | +            Tester(self.Cos(), inputs) | 
|  | 28 | +            .export() | 
|  | 29 | +            .check_count({"torch.ops.aten.cos.default": 1}) | 
|  | 30 | +        ) | 
|  | 31 | + | 
|  | 32 | +        if legacy_mode: | 
|  | 33 | +            tester = tester.to_edge().partition() | 
|  | 34 | +        else: | 
|  | 35 | +            tester = tester.to_edge_transform_and_lower() | 
|  | 36 | + | 
|  | 37 | +        ( | 
|  | 38 | +            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | 
|  | 39 | +            .check_not(["executorch_exir_dialects_edge__ops_aten_cos_default"]) | 
|  | 40 | +            .to_executorch() | 
|  | 41 | +            .serialize() | 
|  | 42 | +            .run_method_and_compare_outputs() | 
|  | 43 | +        ) | 
|  | 44 | + | 
|  | 45 | +    def test_fp16_cos(self): | 
|  | 46 | +        inputs = ( | 
|  | 47 | +            torch.Tensor( | 
|  | 48 | +                [ | 
|  | 49 | +                    [0.0, 0.1, 0.5, 0.785398], | 
|  | 50 | +                    [-0.5, -0.785398, 1.5708, -1.5708], | 
|  | 51 | +                ], | 
|  | 52 | +            ).to(torch.float16), | 
|  | 53 | +        ) | 
|  | 54 | +        self._test_cos(inputs, legacy_mode=False) | 
|  | 55 | + | 
|  | 56 | +    def test_fp16_cos_legacy_mode(self): | 
|  | 57 | +        inputs = ( | 
|  | 58 | +            torch.Tensor( | 
|  | 59 | +                [ | 
|  | 60 | +                    [0.0, 0.1, 0.5, 0.785398], | 
|  | 61 | +                    [-0.5, -0.785398, 1.5708, -1.5708], | 
|  | 62 | +                ], | 
|  | 63 | +            ).to(torch.float16), | 
|  | 64 | +        ) | 
|  | 65 | +        self._test_cos(inputs, legacy_mode=True) | 
|  | 66 | + | 
|  | 67 | +    def test_fp32_cos(self): | 
|  | 68 | +        inputs = ( | 
|  | 69 | +            torch.Tensor( | 
|  | 70 | +                [ | 
|  | 71 | +                    [0.0, 0.1, 0.5, 0.785398], | 
|  | 72 | +                    [-0.5, -0.785398, 1.5708, -1.5708], | 
|  | 73 | +                ], | 
|  | 74 | +            ), | 
|  | 75 | +        ) | 
|  | 76 | +        self._test_cos(inputs, legacy_mode=False) | 
|  | 77 | + | 
|  | 78 | +    def test_fp32_cos_legacy_mode(self): | 
|  | 79 | +        inputs = ( | 
|  | 80 | +            torch.Tensor( | 
|  | 81 | +                [ | 
|  | 82 | +                    [0.0, 0.1, 0.5, 0.785398], | 
|  | 83 | +                    [-0.5, -0.785398, 1.5708, -1.5708], | 
|  | 84 | +                ], | 
|  | 85 | +            ), | 
|  | 86 | +        ) | 
|  | 87 | +        self._test_cos(inputs, legacy_mode=True) | 
0 commit comments