Skip to content

Commit 357b42b

Browse files
authored
add support for custom metric function for mixed precision (#1420)
* add support for custom metric function for mixed precision
1 parent 6345f8d commit 357b42b

File tree

4 files changed

+120
-4
lines changed

4 files changed

+120
-4
lines changed

model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class MixedPrecisionQuantizationConfig:
2727
Args:
2828
compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
2929
distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
30+
custom_metric_fn (Callable): Function to compute a custom metric. It receives the mixed-precision model (model_mp) as input and returns a float metric value. If None, the interest points metric is used.
3031
num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
3132
configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
3233
num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
@@ -39,6 +40,7 @@ class MixedPrecisionQuantizationConfig:
3940

4041
compute_distance_fn: Optional[Callable] = None
4142
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
43+
custom_metric_fn: Optional[Callable] = None
4244
num_of_images: int = MP_DEFAULT_NUM_SAMPLES
4345
configuration_overwrite: Optional[List[int]] = None
4446
num_interest_points_factor: float = field(default=1.0, metadata={"description": "Should be between 0.0 and 1.0"})

model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def compute_metric(cfg, node_idx=None, baseline_cfg=None):
169169
return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
170170
node_idx,
171171
topo_cfg(baseline_cfg) if baseline_cfg else None)
172+
172173
if self.using_virtual_graph:
173174
origin_max_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(
174175
self.max_ru_config)

model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,19 @@ def __init__(self,
8989
self.interest_points = get_mp_interest_points(graph,
9090
fw_impl.count_node_for_mixed_precision_interest_points,
9191
quant_config.num_interest_points_factor)
92+
# If using a custom metric - interest points are not needed, so clear the list
93+
if self.quant_config.custom_metric_fn is not None:
94+
self.interest_points = []
9295

9396
# We use normalized MSE when not running hessian-based. For Hessian-based normalized MSE is not needed
9497
# because hessian weights already do normalization.
9598
use_normalized_mse = self.quant_config.use_hessian_based_scores is False
9699
self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points, use_normalized_mse)
97100

98101
self.output_points = get_output_nodes_for_metric(graph)
102+
# If using a custom metric - return all model outputs
103+
if self.quant_config.custom_metric_fn is not None:
104+
self.output_points = [n.node for n in graph.get_outputs()]
99105
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points,
100106
use_normalized_mse)
101107

@@ -160,7 +166,7 @@ def compute_metric(self,
160166
"""
161167
Compute the sensitivity metric of the MP model for a given configuration (the sensitivity
162168
is computed based on the similarity of the interest points' outputs between the MP model
163-
and the float model).
169+
and the float model or a custom metric if given).
164170
165171
Args:
166172
mp_model_configuration: Bitwidth configuration to use to configure the MP model.
@@ -177,15 +183,21 @@ def compute_metric(self,
177183
node_idx)
178184

179185
# Compute the distance metric
180-
ipts_distances, out_pts_distances = self._compute_distance()
186+
if self.quant_config.custom_metric_fn is None:
187+
ipts_distances, out_pts_distances = self._compute_distance()
188+
sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
189+
self.quant_config.distance_weighting_method)
190+
else:
191+
sensitivity_metric = self.quant_config.custom_metric_fn(self.model_mp)
192+
if not isinstance(sensitivity_metric, (float, np.floating)):
193+
raise TypeError(f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')
181194

182195
# Configure MP model back to the same configuration as the baseline model if baseline provided
183196
if baseline_mp_configuration is not None:
184197
self._configure_bitwidths_model(baseline_mp_configuration,
185198
node_idx)
186199

187-
return self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
188-
self.quant_config.distance_weighting_method)
200+
return sensitivity_metric
189201

190202
def _init_baseline_tensors_list(self):
191203
"""
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import pytest
16+
import numpy as np
17+
18+
from unittest.mock import Mock
19+
20+
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
21+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
22+
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
23+
24+
25+
def custom_float_metric(model_mp) -> float:
    """Dummy custom metric: ignores the model and returns a constant Python float."""
    constant_score = 100.0
    return constant_score
27+
28+
29+
def custom_np_float_metric(model_mp) -> np.floating:
    """Dummy custom metric: ignores the model and returns a constant numpy float64."""
    constant_score = np.float64(100.0)
    return constant_score
31+
32+
33+
def custom_str_metric(model_mp) -> str:
    """Invalid custom metric: returns a string instead of a float (used to trigger TypeError)."""
    bad_result = 'test'
    return bad_result
35+
36+
37+
def custom_none_metric(model_mp):
    """Invalid custom metric: returns None instead of a float (used to trigger TypeError)."""
    return None
39+
40+
41+
def get_sensitivity_evaluator(custom_metric_fn):
    """Build a SensitivityEvaluation wired entirely with mocks.

    Args:
        custom_metric_fn: Callable passed as MixedPrecisionQuantizationConfig.custom_metric_fn.

    Returns:
        A SensitivityEvaluation instance whose graph, fw_info and fw_impl are mocks,
        with bitwidth configuration stubbed out and a mocked MP model attached.
    """
    graph_mock = Mock()
    graph_mock.get_topo_sorted_nodes.return_value = ['test', 'this', 'is', 'reset']
    graph_mock.get_outputs.return_value = []

    def representative_data_gen() -> list:
        # Yield five random image-like batches.
        for _ in range(5):
            yield np.random.randn(2, 3, 248, 248)

    fw_info_mock = Mock()

    def model_builder_side_effect(*args, **kwargs):
        # The float model builder returns a 2-tuple; the MP builder returns a 3-tuple.
        if kwargs.get('mode') == ModelBuilderMode.FLOAT:
            return (None, None)
        return (None, None, None)

    fw_impl_mock = Mock()
    fw_impl_mock.model_builder.side_effect = model_builder_side_effect
    fw_impl_mock.to_tensor.side_effect = lambda img: img  # Pass tensors through unchanged.

    quant_cfg = MixedPrecisionQuantizationConfig(custom_metric_fn=custom_metric_fn)

    evaluator = SensitivityEvaluation(graph=graph_mock,
                                      quant_config=quant_cfg,
                                      representative_data_gen=representative_data_gen,
                                      fw_info=fw_info_mock,
                                      fw_impl=fw_impl_mock,
                                      set_layer_to_bitwidth=Mock()
                                      )
    evaluator._configure_bitwidths_model = lambda *args, **kwargs: None  # Method does nothing
    evaluator.model_mp = Mock()
    return evaluator
80+
81+
82+
class TestMPCustomMetricFunction:
    """Tests for the mixed-precision custom metric function hook.

    A custom metric function replaces the interest-points distance metric:
    interest points must be cleared, valid float/numpy-float returns are used
    as the sensitivity metric, and non-float returns raise a TypeError.
    """

    @pytest.mark.parametrize("metric_fn, expected", [
        (custom_float_metric, 100.0),
        (custom_np_float_metric, np.float64(100.0)),
    ])
    def test_valid_metric_function(self, metric_fn, expected):
        # With a custom metric, interest points are dropped and the metric's
        # return value is propagated as-is by compute_metric.
        sensitivity_eval = get_sensitivity_evaluator(metric_fn)
        assert len(sensitivity_eval.interest_points) == 0
        assert sensitivity_eval.compute_metric(Mock()) == expected

    @pytest.mark.parametrize("metric_fn, expected", [
        (custom_str_metric, str.__name__),
        (custom_none_metric, type(None).__name__),
    ])
    def test_type_invalid_metric_function(self, metric_fn, expected):
        # A custom metric returning anything but float/np.floating must raise
        # a TypeError that names the offending type.
        sensitivity_eval = get_sensitivity_evaluator(metric_fn)
        assert len(sensitivity_eval.interest_points) == 0
        with pytest.raises(TypeError, match=f'The custom_metric_fn is expected to return float or numpy float, got {expected}'):
            # Fix: drop the unused `sensitivity_metric = ...` assignment (F841);
            # only the raised exception matters here.
            sensitivity_eval.compute_metric(Mock())

0 commit comments

Comments
 (0)