Skip to content

Commit dc6fc92

Browse files
Add test
1 parent 1e366b6 commit dc6fc92

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

model_compression_toolkit/core/common/model_collector.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,6 @@ def infer(self, inputs_list: List[np.ndarray]):
256256
hessian_tensors += [None for _ in range(len(activation_tensors) - len(hessian_tensors))]
257257

258258
for activation_tensor, hessian_tensor, stats_container in zip(activation_tensors, hessian_tensors, self.stats_containers_list):
259-
print('activation_tensor', activation_tensor.shape)
260259
if isinstance(stats_container, (list, tuple)):
261260
if hessian_tensor is None:
262261
hessian_tensor = [None for _ in range(len(activation_tensor))]

tests/pytorch_tests/function_tests/test_torch_utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ def setUp(self):
2929
self.list_of_numbers = [1, 2, 3]
3030
self.tuple_of_numbers = (1, 2, 3)
3131

32+
self.scalar_numpy_array = np.array([1.25])
33+
self.scalar_torch_tensor = torch.tensor(1.25)
34+
3235
@patch('model_compression_toolkit.core.pytorch.pytorch_device_config.get_working_device')
3336
def test_to_torch_tensor_with_numpy_array(self, mock_get_device):
3437
mock_get_device.return_value = 'cpu'
@@ -69,21 +72,31 @@ def test_torch_tensor_to_numpy_with_torch_tensor(self):
6972
np.testing.assert_array_almost_equal(result, self.numpy_array)
7073

7174
def test_torch_tensor_to_numpy_with_scalar_tensor(self):
72-
scalar_tensor = torch.tensor(1.25)
73-
result = torch_tensor_to_numpy(scalar_tensor)
75+
result = torch_tensor_to_numpy(self.scalar_torch_tensor)
7476
self.assertEqual(result.shape, (1,))
75-
np.testing.assert_array_almost_equal(result, np.array([1.25]))
77+
np.testing.assert_array_almost_equal(result, self.scalar_numpy_array)
7678

7779
def test_torch_tensor_to_numpy_with_list(self):
7880
result = torch_tensor_to_numpy([self.torch_tensor, self.torch_tensor])
7981
self.assertEqual(len(result), 2)
8082
self.assertTrue(all(isinstance(x, np.ndarray) for x in result))
8183

84+
def test_torch_tensor_to_numpy_with_scalar_list(self):
85+
result = torch_tensor_to_numpy([self.scalar_torch_tensor, self.scalar_torch_tensor])
86+
self.assertEqual(len(result), 2)
87+
self.assertTrue(all(isinstance(x, np.ndarray) for x in result))
88+
self.assertTrue(all(x.shape == (1,) for x in result))
89+
8290
def test_torch_tensor_to_numpy_with_tuple(self):
8391
result = torch_tensor_to_numpy((self.torch_tensor, self.torch_tensor))
8492
self.assertEqual(len(result), 2)
8593
self.assertTrue(all(isinstance(x, np.ndarray) for x in result))
8694

95+
def test_torch_tensor_to_numpy_with_scalar_tuple(self):
96+
result = torch_tensor_to_numpy((self.scalar_torch_tensor, self.scalar_torch_tensor))
97+
self.assertEqual(len(result), 2)
98+
self.assertTrue(all(isinstance(x, np.ndarray) for x in result))
99+
87100
@patch('model_compression_toolkit.logger.Logger')
88101
def test_torch_tensor_to_numpy_with_unsupported_type(self, mock_logger):
89102
with self.assertRaises(Exception):

0 commit comments

Comments
 (0)