1+ # Copyright 2026 Sony Semiconductor Solutions, Inc. All rights reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+ # ==============================================================================
15+ import pytest
16+ from tests_pytest ._test_util .graph_builder_utils import build_node
17+ from unittest .mock import Mock
18+ import numpy as np
19+ import torch
20+ from model_compression_toolkit .core .pytorch .utils import torch_tensor_to_numpy
21+ from model_compression_toolkit .core .common import StatsCollector , Graph
22+ from model_compression_toolkit .core .common .graph .base_graph import OutTensor
23+ from model_compression_toolkit .core .common .graph .edge import Edge
24+ from model_compression_toolkit .core .common .model_collector import ModelCollector
25+ from model_compression_toolkit .defaultdict import DefaultDict
26+
27+
# Stand-in layer classes used only as keys of the out-channel-axis mapping
# built in the fw_info_mock fixture below (Conv-like layers map to axis 1).
class Conv2D:
    pass
30+
# Stand-in for a fully-connected layer type; maps to channel axis -1 in the
# fw_info_mock fixture below.
class Linear:
    pass
33+
# Stand-in for a transposed-convolution layer type; maps to channel axis 1 in
# the fw_info_mock fixture below.
class ConvTranspose2d:
    pass
36+
# Not referenced in the visible tests; presumably present to exercise the
# DefaultDict fallback axis (1) for unmapped layer types — TODO confirm.
class DummyLayer:
    pass
39+
@pytest.fixture
def fw_impl_mock():
    """Mocked framework implementation whose ``model_builder`` returns a
    dummy model object paired with ``None``."""
    impl = Mock()
    impl.model_builder.return_value = (Mock(), None)
    return impl
45+
@pytest.fixture
def fw_info_mock():
    """Mocked framework info exposing a per-layer-type out-channel-axis map.

    Conv-like layer types report channels on axis 1, Linear on the last
    axis; any unmapped type falls back to the DefaultDict default of 1.
    """
    info = Mock()
    info.out_channel_axis_mapping = DefaultDict(
        {Conv2D: 1, Linear: -1, ConvTranspose2d: 1}, 1)
    return info
51+
52+
class TestModelCollectorInit:
    """Verifies the channel axis assigned to each node's stats collector
    when a ModelCollector is constructed over a simple linear graph."""

    def test_init(self, fw_impl_mock, fw_info_mock):
        # One node per output rank: 4D, 2D, 1D and scalar.
        shapes = [(1, 3, 2, 2), (3, 2), (4,), ()]
        nodes = [build_node(f'node{i}', output_shape=shape)
                 for i, shape in enumerate(shapes)]
        for node in nodes:
            node.is_activation_quantization_enabled = Mock(return_value=True)
            node.is_fln_quantization = Mock(return_value=False)

        graph = Graph('g',
                      input_nodes=[nodes[0]],
                      nodes=nodes,
                      output_nodes=[OutTensor(nodes[-1], 0)],
                      edge_list=[Edge(src, dst, 0, 0)
                                 for src, dst in zip(nodes, nodes[1:])])
        graph.set_out_stats_collector_to_node = Mock(
            wraps=graph.set_out_stats_collector_to_node)

        fw_info_mock.get_kernel_op_attributes.return_value = [None]

        ModelCollector(graph, fw_impl_mock, fw_info_mock)

        # Tensors of rank >= 2 use the framework channel axis (1 here);
        # 1D tensors and scalars fall back to axis -1.
        expected_axis = [1, 1, -1, -1]
        for node, axis in zip(graph.nodes, expected_axis):
            collector = graph.get_out_stats_collector(node)
            assert isinstance(collector, StatsCollector)
            assert collector.mpcc.axis == axis
            assert collector.mc.axis == axis
86+
87+
class TestModelCollectorInfer:
    """Verifies that ModelCollector.infer aggregates per-channel statistics
    (mean / min / max) across two inference batches for tensors of every
    rank from scalar up to 4D."""

    def test_infer(self, fw_impl_mock, fw_info_mock):
        # One node per output rank: 4D, 2D, 1D and scalar.
        shapes = [(1, 3, 2, 2), (3, 2), (4,), ()]
        nodes = [build_node(f'node{i}', output_shape=shape)
                 for i, shape in enumerate(shapes)]
        for node in nodes:
            node.is_activation_quantization_enabled = Mock(return_value=True)
            node.is_fln_quantization = Mock(return_value=False)

        graph = Graph('g',
                      input_nodes=[nodes[0]],
                      nodes=nodes,
                      output_nodes=[OutTensor(nodes[-1], 0)],
                      edge_list=[Edge(src, dst, 0, 0)
                                 for src, dst in zip(nodes, nodes[1:])])

        fw_info_mock.get_kernel_op_attributes.return_value = [None]

        # Two inference batches: one activation tensor per node, in node order.
        batch1 = [
            torch.tensor(
                [[
                    [[1.0, 2.0], [3.0, 4.0]],
                    [[-1.0, -2.0], [-3.0, -4.0]],
                    [[10.0, 10.0], [10.0, 10.0]],
                ]],
                dtype=torch.float32,
            ),
            torch.tensor([[1.0, 3.0], [2.0, 4.0], [3.0, 5.0]], dtype=torch.float32),
            torch.tensor([2.0, 6.0, 8.0, 10.0], dtype=torch.float32),
            torch.tensor(10.0, dtype=torch.float32),
        ]
        batch2 = [
            torch.tensor(
                [[
                    [[5.0, 6.0], [7.0, 8.0]],
                    [[0.0, 1.0], [2.0, 3.0]],
                    [[-10.0, -20.0], [-30.0, -40.0]],
                ]],
                dtype=torch.float32,
            ),
            torch.tensor([[5.0, -1.0], [6.0, -2.0], [7.0, -3.0]], dtype=torch.float32),
            torch.tensor([4.0, 8.0, 12.0, 16.0], dtype=torch.float32),
            torch.tensor(-2.0, dtype=torch.float32),
        ]

        fw_impl_mock.to_numpy.side_effect = torch_tensor_to_numpy
        fw_impl_mock.run_model_inference.side_effect = [batch1, batch2]

        collector = ModelCollector(graph, fw_impl_mock, fw_info_mock)

        dummy_input = [np.random.randn(1, 3, 2, 2)]
        collector.infer(dummy_input)
        collector.infer(dummy_input)

        # Expected statistics per node, hand-computed from the two batches:
        #   node0 (axis=1):  batch channel means [2.5, -2.5, 10.0] and
        #                    [6.5, 1.5, -25.0] -> mean [4.5, -0.5, -7.5]
        #   node1 (axis=1):  batch channel means [2, 4] and [6, -2] -> [4, 1]
        #   node2 (axis=-1): batch means equal the element values
        #   node3 (axis=-1): scalar means 10 and -2 -> 4
        expected = [
            (nodes[0], [4.5, -0.5, -7.5], -40.0, 10.0),
            (nodes[1], [4.0, 1.0], -3.0, 7.0),
            (nodes[2], [3.0, 7.0, 10.0, 13.0], 2.0, 16.0),
            (nodes[3], [4.0], -2.0, 10.0),
        ]
        for node, mean, expected_min, expected_max in expected:
            stats = graph.get_out_stats_collector(node)
            np.testing.assert_allclose(stats.get_mean(), np.array(mean))
            min_v, max_v = stats.get_min_max_values()
            assert min_v == expected_min
            assert max_v == expected_max