Skip to content

Commit f2c2e9d

Browse files
Add integration test
1 parent 25b9b1d commit f2c2e9d

File tree

3 files changed

+276
-150
lines changed

3 files changed

+276
-150
lines changed

tests_pytest/common_tests/unit_tests/core/test_create_stats_collector_for_node.py

Lines changed: 0 additions & 150 deletions
This file was deleted.
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright 2026 Sony Semiconductor Solutions, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import pytest
16+
from tests_pytest._test_util.graph_builder_utils import build_node
17+
from unittest.mock import Mock
18+
import numpy as np
19+
import torch
20+
from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy
21+
from model_compression_toolkit.core.common import StatsCollector, Graph
22+
from model_compression_toolkit.core.common.graph.base_graph import OutTensor
23+
from model_compression_toolkit.core.common.graph.edge import Edge
24+
from model_compression_toolkit.core.common.model_collector import ModelCollector
25+
from model_compression_toolkit.defaultdict import DefaultDict
26+
27+
28+
class Conv2D:
29+
pass
30+
31+
class Linear:
32+
pass
33+
34+
class ConvTranspose2d:
35+
pass
36+
37+
class DummyLayer:
38+
pass
39+
40+
@pytest.fixture
41+
def fw_impl_mock():
42+
fw_impl = Mock()
43+
fw_impl.model_builder.return_value = (Mock(), None)
44+
return fw_impl
45+
46+
@pytest.fixture
47+
def fw_info_mock():
48+
fw_info = Mock()
49+
fw_info.out_channel_axis_mapping = DefaultDict({Conv2D: 1, Linear: -1, ConvTranspose2d: 1}, 1)
50+
return fw_info
51+
52+
53+
class TestModelCollectorInit:
54+
55+
def test_init(self, fw_impl_mock, fw_info_mock):
56+
node0 = build_node('node0', output_shape=(1, 3, 2, 2)) # 4D tensor
57+
node1 = build_node('node1', output_shape=(3, 2)) # 2D tensor
58+
node2 = build_node('node2', output_shape=(4,)) # 1D tensor
59+
node3 = build_node('node3', output_shape=()) # Scalar
60+
61+
mock_nodes_list = [node0, node1, node2, node3]
62+
for node in mock_nodes_list:
63+
node.is_activation_quantization_enabled = Mock(return_value=True)
64+
node.is_fln_quantization = Mock(return_value=False)
65+
66+
graph = Graph('g',
67+
input_nodes=[node0],
68+
nodes=mock_nodes_list,
69+
output_nodes=[OutTensor(node3, 0)],
70+
edge_list=[Edge(node0, node1, 0, 0), Edge(node1, node2, 0, 0), Edge(node2, node3, 0, 0)])
71+
graph.set_out_stats_collector_to_node = Mock(wraps=graph.set_out_stats_collector_to_node)
72+
73+
fw_info_mock.get_kernel_op_attributes.return_value = [None]
74+
75+
mc = ModelCollector(graph, fw_impl_mock, fw_info_mock)
76+
77+
# If output shape is scalar or 1D tensor, the axis should be -1.
78+
# If output shape is 2D tensor, the axis should be 1.
79+
expected_axis = [1, 1, -1, -1]
80+
for node, expected in zip(graph.nodes, expected_axis):
81+
out_stats_container = graph.get_out_stats_collector(node)
82+
assert isinstance(out_stats_container, StatsCollector)
83+
84+
assert out_stats_container.mpcc.axis == expected
85+
assert out_stats_container.mc.axis == expected
86+
87+
88+
class TestModelCollectorInfer:
89+
90+
def test_infer(self, fw_impl_mock, fw_info_mock):
91+
node0 = build_node('node0', output_shape=(1, 3, 2, 2)) # 4D tensor
92+
node1 = build_node('node1', output_shape=(3, 2)) # 2D tensor
93+
node2 = build_node('node2', output_shape=(4,)) # 1D tensor
94+
node3 = build_node('node3', output_shape=()) # scalar
95+
96+
mock_nodes_list = [node0, node1, node2, node3]
97+
for node in mock_nodes_list:
98+
node.is_activation_quantization_enabled = Mock(return_value=True)
99+
node.is_fln_quantization = Mock(return_value=False)
100+
101+
graph = Graph('g',
102+
input_nodes=[node0],
103+
nodes=mock_nodes_list,
104+
output_nodes=[OutTensor(node3, 0)],
105+
edge_list=[Edge(node0, node1, 0, 0), Edge(node1, node2, 0, 0), Edge(node2, node3, 0, 0)])
106+
107+
fw_info_mock.get_kernel_op_attributes.return_value = [None]
108+
109+
infer1 = [
110+
torch.tensor(
111+
[[
112+
[[1.0, 2.0], [3.0, 4.0]],
113+
[[-1.0, -2.0], [-3.0, -4.0]],
114+
[[10.0, 10.0], [10.0, 10.0]],
115+
]],
116+
dtype=torch.float32,
117+
),
118+
torch.tensor([[1.0, 3.0], [2.0, 4.0], [3.0, 5.0]], dtype=torch.float32),
119+
torch.tensor([2.0, 6.0, 8.0, 10.0], dtype=torch.float32),
120+
torch.tensor(10.0, dtype=torch.float32),
121+
]
122+
infer2 = [
123+
torch.tensor(
124+
[[
125+
[[5.0, 6.0], [7.0, 8.0]],
126+
[[0.0, 1.0], [2.0, 3.0]],
127+
[[-10.0, -20.0], [-30.0, -40.0]],
128+
]],
129+
dtype=torch.float32,
130+
),
131+
torch.tensor([[5.0, -1.0], [6.0, -2.0], [7.0, -3.0]], dtype=torch.float32),
132+
torch.tensor([4.0, 8.0, 12.0, 16.0], dtype=torch.float32),
133+
torch.tensor(-2.0, dtype=torch.float32),
134+
]
135+
136+
fw_impl_mock.to_numpy.side_effect = torch_tensor_to_numpy
137+
fw_impl_mock.run_model_inference.side_effect = [infer1, infer2]
138+
139+
mc = ModelCollector(graph, fw_impl_mock, fw_info_mock)
140+
141+
dummy_input = [np.random.randn(1, 3, 2, 2)]
142+
mc.infer(dummy_input)
143+
mc.infer(dummy_input)
144+
145+
sc0 = graph.get_out_stats_collector(node0)
146+
sc1 = graph.get_out_stats_collector(node1)
147+
sc2 = graph.get_out_stats_collector(node2)
148+
sc3 = graph.get_out_stats_collector(node3)
149+
150+
# node0 (axis=1)
151+
# infer1 channel means: [2.5, -2.5, 10.0]
152+
# infer2 channel means: [6.5, 1.5, -25.0]
153+
# final mean: [4.5, -0.5, -7.5]
154+
np.testing.assert_allclose(sc0.get_mean(), np.array([4.5, -0.5, -7.5]))
155+
min_v, max_v = sc0.get_min_max_values()
156+
assert min_v == -40.0
157+
assert max_v == 10.0
158+
159+
# node1 (axis=1)
160+
# infer1 channel means: [2, 4]
161+
# infer2 channel means: [6, -2]
162+
# final mean: [4, 1]
163+
np.testing.assert_allclose(sc1.get_mean(), np.array([4.0, 1.0]))
164+
min_v, max_v = sc1.get_min_max_values()
165+
assert min_v == -3.0
166+
assert max_v == 7.0
167+
168+
# node2 (axis=-1)
169+
# infer1 channel means: [2, 6, 8, 10]
170+
# infer2 channel means: [4, 8, 12, 16]
171+
# final mean: [3, 7, 10, 13]
172+
np.testing.assert_allclose(sc2.get_mean(), np.array([3.0, 7.0, 10.0, 13.0]))
173+
min_v, max_v = sc2.get_min_max_values()
174+
assert min_v == 2.0
175+
assert max_v == 16.0
176+
177+
# node3 (axis=-1)
178+
# infer1 channel means: 10
179+
# infer2 channel means: -2
180+
# final mean: 4
181+
np.testing.assert_allclose(sc3.get_mean(), np.array([4.0]))
182+
min_v, max_v = sc3.get_min_max_values()
183+
assert min_v == -2.0
184+
assert max_v == 10.0

0 commit comments

Comments
 (0)