Skip to content

Commit b70bb65

Browse files
committed
add unit test to mp custom metric
1 parent 85b0de7 commit b70bb65

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import pytest
16+
import numpy as np
17+
18+
from unittest.mock import Mock
19+
20+
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
21+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
22+
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
23+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
24+
25+
26+
def custom_float_metric(model_mp) -> float:
    """Stub custom metric: ignores the model and yields a plain Python float (a valid return type)."""
    score = 100.0
    return score
28+
29+
30+
def custom_np_float_metric(model_mp) -> np.floating:
    """Stub custom metric: ignores the model and yields a numpy float (a valid return type)."""
    score = np.float64(100.0)
    return score
32+
33+
34+
def custom_str_metric(model_mp) -> str:
    """Stub custom metric with an invalid return type (str) — used to trigger type validation."""
    bad_score = 'test'
    return bad_score
36+
37+
38+
def custom_none_metric(model_mp):
    """Stub custom metric with an invalid return type (None) — used to trigger type validation."""
    return None
40+
41+
42+
@pytest.fixture
def mock_model_configuration():
    """Provide a stand-in mixed-precision model configuration (an opaque Mock)."""
    configuration = Mock()
    return configuration
45+
46+
47+
@pytest.fixture
def sensitivity_evaluator_factory():
    """Provide a factory building a SensitivityEvaluation wired to a given custom metric.

    The factory accepts a custom metric callable, wraps it in a
    MixedPrecisionQuantizationConfig, and returns the mocked evaluator.
    """
    def _create_sensitivity_evaluator(custom_metric_fn):
        # Only the custom_metric_fn varies between tests; everything else is mocked.
        quant_cfg = MixedPrecisionQuantizationConfig(custom_metric_fn=custom_metric_fn)
        return get_sensitivity_evaluator(mp_cfg=quant_cfg)

    return _create_sensitivity_evaluator
53+
54+
55+
def get_sensitivity_evaluator(mp_cfg):
    """Build a SensitivityEvaluation around mocks so no real framework/model is needed.

    Args:
        mp_cfg: MixedPrecisionQuantizationConfig carrying the custom metric under test.

    Returns:
        A SensitivityEvaluation whose graph, framework info/impl and model builder
        are all mocked, with bit-width configuration stubbed out to a no-op.
    """
    # Graph mock: a few dummy topologically-sorted nodes, no model outputs.
    mock_graph = Mock()
    mock_graph.get_topo_sorted_nodes.return_value = ['test', 'this', 'is', 'reset']
    mock_graph.get_outputs.return_value = []

    def representative_data_gen() -> list:
        # Five random batches of shape (2, 3, 248, 248).
        for _ in range(5):
            yield np.random.randn(2, 3, 248, 248)

    mock_fw_info = Mock()

    def custom_model_builder_return_value(*args, **kwargs):
        # The float builder returns a 2-tuple; other modes return a 3-tuple.
        if kwargs.get('mode') == ModelBuilderMode.FLOAT:
            return (None, None)
        return (None, None, None)

    def custom_to_tensor(img):
        return to_torch_tensor(img)

    # Framework-implementation mock delegating to the stubs above.
    mock_fw_impl = Mock()
    mock_fw_impl.model_builder.side_effect = custom_model_builder_return_value
    mock_fw_impl.to_tensor.side_effect = custom_to_tensor

    mock_set_layer_to_bitwidth = Mock()

    sensitivity_eval = SensitivityEvaluation(graph=mock_graph,
                                             quant_config=mp_cfg,
                                             representative_data_gen=representative_data_gen,
                                             fw_info=mock_fw_info,
                                             fw_impl=mock_fw_impl,
                                             set_layer_to_bitwidth=mock_set_layer_to_bitwidth
                                             )
    # Bit-width configuration is irrelevant here — stub it to a no-op.
    sensitivity_eval._configure_bitwidths_model = lambda *args, **kwargs: None
    sensitivity_eval.model_mp = Mock()
    return sensitivity_eval
92+
93+
94+
class TestMPCustomMetricFunction:
    """Tests for the mixed-precision custom sensitivity-metric hook.

    A custom metric returning a float / numpy float must be passed through
    by ``compute_metric``; any other return type must raise a ``TypeError``
    naming the offending type. When a custom metric is set, no interest
    points are expected to be collected.
    """

    @pytest.mark.parametrize("metric_fn, expected", [
        (custom_float_metric, 100.0),
        (custom_np_float_metric, np.float64(100.0)),
    ])
    def test_valid_metric_function(self, sensitivity_evaluator_factory, mock_model_configuration, metric_fn, expected):
        """A float-like custom metric value is returned unchanged as the sensitivity score."""
        sensitivity_eval = sensitivity_evaluator_factory(metric_fn)
        # Custom metric path should skip interest-point collection entirely.
        assert len(sensitivity_eval.interest_points) == 0
        assert sensitivity_eval.compute_metric(mock_model_configuration) == expected

    @pytest.mark.parametrize("metric_fn, expected", [
        (custom_str_metric, str.__name__),
        (custom_none_metric, type(None).__name__),
    ])
    def test_type_invalid_metric_function(self, sensitivity_evaluator_factory, mock_model_configuration, metric_fn, expected):
        """A non-float custom metric return value raises a descriptive TypeError."""
        sensitivity_eval = sensitivity_evaluator_factory(metric_fn)
        assert len(sensitivity_eval.interest_points) == 0
        with pytest.raises(TypeError, match=f'The custom_metric_fn is expected to return float or numpy float, got {expected}'):
            # Fix: result was previously bound to an unused local (`sensitivity_metric`);
            # call directly since only the raised exception matters.
            sensitivity_eval.compute_metric(mock_model_configuration)

0 commit comments

Comments
 (0)