@@ -29,6 +29,9 @@ def setUp(self):
2929 self .list_of_numbers = [1 , 2 , 3 ]
3030 self .tuple_of_numbers = (1 , 2 , 3 )
3131
32+ self .scalar_numpy_array = np .array ([1.25 ])
33+ self .scalar_torch_tensor = torch .tensor (1.25 )
34+
3235 @patch ('model_compression_toolkit.core.pytorch.pytorch_device_config.get_working_device' )
3336 def test_to_torch_tensor_with_numpy_array (self , mock_get_device ):
3437 mock_get_device .return_value = 'cpu'
@@ -69,21 +72,31 @@ def test_torch_tensor_to_numpy_with_torch_tensor(self):
6972 np .testing .assert_array_almost_equal (result , self .numpy_array )
7073
7174 def test_torch_tensor_to_numpy_with_scalar_tensor (self ):
72- scalar_tensor = torch .tensor (1.25 )
73- result = torch_tensor_to_numpy (scalar_tensor )
75+ result = torch_tensor_to_numpy (self .scalar_torch_tensor )
7476 self .assertEqual (result .shape , (1 ,))
75- np .testing .assert_array_almost_equal (result , np . array ([ 1.25 ]) )
77+ np .testing .assert_array_almost_equal (result , self . scalar_numpy_array )
7678
7779 def test_torch_tensor_to_numpy_with_list (self ):
7880 result = torch_tensor_to_numpy ([self .torch_tensor , self .torch_tensor ])
7981 self .assertEqual (len (result ), 2 )
8082 self .assertTrue (all (isinstance (x , np .ndarray ) for x in result ))
8183
84+ def test_torch_tensor_to_numpy_with_scalar_list (self ):
85+ result = torch_tensor_to_numpy ([self .scalar_torch_tensor , self .scalar_torch_tensor ])
86+ self .assertEqual (len (result ), 2 )
87+ self .assertTrue (all (isinstance (x , np .ndarray ) for x in result ))
88+ self .assertTrue (all (x .shape == (1 ,) for x in result ))
89+
8290 def test_torch_tensor_to_numpy_with_tuple (self ):
8391 result = torch_tensor_to_numpy ((self .torch_tensor , self .torch_tensor ))
8492 self .assertEqual (len (result ), 2 )
8593 self .assertTrue (all (isinstance (x , np .ndarray ) for x in result ))
8694
95+ def test_torch_tensor_to_numpy_with_scalar_tuple (self ):
96+ result = torch_tensor_to_numpy ((self .scalar_torch_tensor , self .scalar_torch_tensor ))
97+ self .assertEqual (len (result ), 2 )
98+ self .assertTrue (all (isinstance (x , np .ndarray ) for x in result ))
99+
87100 @patch ('model_compression_toolkit.logger.Logger' )
88101 def test_torch_tensor_to_numpy_with_unsupported_type (self , mock_logger ):
89102 with self .assertRaises (Exception ):
0 commit comments