1- # Copyright 2024 Arm Limited and/or its affiliates.
1+ # Copyright 2024-2025 Arm Limited and/or its affiliates.
22# All rights reserved.
33#
44# This source code is licensed under the BSD-style license found in the
77import unittest
88
99import torch
10- from executorch .backends .arm .test import common
10+ from executorch .backends .arm .test import common , conftest
1111from executorch .backends .arm .test .tester .arm_tester import ArmTester
1212from parameterized import parameterized
1313
1414
1515class TestRshift (unittest .TestCase ):
16- """
17- Tests arithmetic right shift
18- """
16+ """Tests arithmetic right shift"""
1917
2018 class Rshift (torch .nn .Module ):
2119 test_data = [
2220 ((torch .IntTensor (5 , 5 ), 2 ),),
2321 ((torch .IntTensor (1 , 2 , 3 , 4 ), 3 ),),
22+ ((torch .CharTensor (1 , 12 , 3 , 4 ), 1 ),),
2423 ((torch .ShortTensor (1 , 5 , 3 , 4 ), 5 ),),
25- ((torch .CharTensor (10 , 12 , 3 , 4 ), 1 ),),
2624 ]
2725
2826 def forward (self , x : torch .Tensor , shift : int ):
@@ -52,8 +50,7 @@ def _test_rshift_tosa_BI(self, test_data):
5250 .export ()
5351 .to_edge_transform_and_lower ()
5452 .to_executorch ()
55- # TODO MLETORCH-250 Increase flexibility of ArmTester to handle int IO
56- # .run_method_and_compare_outputs(inputs=test_data)
53+ .run_method_and_compare_outputs (inputs = test_data )
5754 )
5855
5956 def _test_rshift_ethosu_BI (self , test_data , compile_spec ):
@@ -67,6 +64,7 @@ def _test_rshift_ethosu_BI(self, test_data, compile_spec):
6764 .export ()
6865 .to_edge_transform_and_lower ()
6966 .to_executorch ()
67+ .serialize ()
7068 )
7169
7270 @parameterized .expand (Rshift .test_data )
@@ -77,14 +75,18 @@ def test_rshift_tosa_MI(self, test_data):
7775 def test_rshift_tosa_BI (self , test_data ):
7876 self ._test_rshift_tosa_BI (test_data )
7977
80- # TODO Enable FVP testing
81- @parameterized .expand (Rshift .test_data )
78+ # TODO: MLETORCH-644 - Add support for INT16 input/output
79+ @parameterized .expand (Rshift .test_data [: - 1 ] )
8280 def test_rshift_u55_BI (self , test_data ):
8381 compile_spec = common .get_u55_compile_spec ()
84- self ._test_rshift_ethosu_BI (test_data , compile_spec )
82+ tester = self ._test_rshift_ethosu_BI (test_data , compile_spec )
83+ if conftest .is_option_enabled ("corstone_fvp" ):
84+ tester .run_method_and_compare_outputs (atol = 1 , inputs = test_data )
8585
86- # TODO Enable FVP testing
87- @parameterized .expand (Rshift .test_data )
86+ # TODO: MLETORCH-644 - Add support for INT16 input/output
87+ @parameterized .expand (Rshift .test_data [: - 1 ] )
8888 def test_rshift_u85_BI (self , test_data ):
8989 compile_spec = common .get_u85_compile_spec ()
90- self ._test_rshift_ethosu_BI (test_data , compile_spec )
90+ tester = self ._test_rshift_ethosu_BI (test_data , compile_spec )
91+ if conftest .is_option_enabled ("corstone_fvp" ):
92+ tester .run_method_and_compare_outputs (inputs = test_data )
0 commit comments