diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py index d729c6924..c3e6a742d 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py @@ -27,6 +27,7 @@ class MixedPrecisionQuantizationConfig: Args: compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer. distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric. + custom_metric_fn (Callable): Function to compute a custom sensitivity metric. Receives the mixed-precision model (model_mp) as input and returns a float metric value. If None, the default interest-points metric is used. num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model. configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one. num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric. 
@@ -39,6 +40,7 @@ class MixedPrecisionQuantizationConfig: compute_distance_fn: Optional[Callable] = None distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG + custom_metric_fn: Optional[Callable] = None num_of_images: int = MP_DEFAULT_NUM_SAMPLES configuration_overwrite: Optional[List[int]] = None num_interest_points_factor: float = field(default=1.0, metadata={"description": "Should be between 0.0 and 1.0"}) diff --git a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py index a2c805403..4b14ca9e5 100644 --- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py @@ -165,6 +165,7 @@ def compute_metric(cfg, node_idx=None, baseline_cfg=None): return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg), node_idx, topo_cfg(baseline_cfg) if baseline_cfg else None) + if self.using_virtual_graph: origin_max_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph( self.max_ru_config) diff --git a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py index 30cbf0574..b457d23ce 100644 --- a/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +++ b/model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py @@ -89,6 +89,9 @@ def __init__(self, self.interest_points = get_mp_interest_points(graph, fw_impl.count_node_for_mixed_precision_interest_points, quant_config.num_interest_points_factor) + # If using a custom metric - skip interest points; only model outputs are needed + if self.quant_config.custom_metric_fn is not None: + self.interest_points = [] # We use normalized MSE when not running hessian-based. 
For Hessian-based normalized MSE is not needed # because hessian weights already do normalization. @@ -96,6 +99,9 @@ def __init__(self, self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points, use_normalized_mse) self.output_points = get_output_nodes_for_metric(graph) + # If using a custom metric - return all model outputs + if self.quant_config.custom_metric_fn is not None: + self.output_points = [n.node for n in graph.get_outputs()] self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points, use_normalized_mse) @@ -160,7 +166,7 @@ def compute_metric(self, """ Compute the sensitivity metric of the MP model for a given configuration (the sensitivity is computed based on the similarity of the interest points' outputs between the MP model - and the float model). + and the float model or a custom metric if given). Args: mp_model_configuration: Bitwidth configuration to use to configure the MP model. @@ -177,15 +183,21 @@ def compute_metric(self, node_idx) # Compute the distance metric - ipts_distances, out_pts_distances = self._compute_distance() + if self.quant_config.custom_metric_fn is None: + ipts_distances, out_pts_distances = self._compute_distance() + sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances, + self.quant_config.distance_weighting_method) + else: + sensitivity_metric = self.quant_config.custom_metric_fn(self.model_mp) + if not isinstance(sensitivity_metric, (float, np.floating)): + raise TypeError(f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}') # Configure MP model back to the same configuration as the baseline model if baseline provided if baseline_mp_configuration is not None: self._configure_bitwidths_model(baseline_mp_configuration, node_idx) - return self._compute_mp_distance_measure(ipts_distances, out_pts_distances, - self.quant_config.distance_weighting_method) + return 
sensitivity_metric def _init_baseline_tensors_list(self): """ diff --git a/tests_pytest/common_tests/unit_tests/core/mixed_precision/test_custom_metric_function.py b/tests_pytest/common_tests/unit_tests/core/mixed_precision/test_custom_metric_function.py new file mode 100644 index 000000000..1ca7a7f8c --- /dev/null +++ b/tests_pytest/common_tests/unit_tests/core/mixed_precision/test_custom_metric_function.py @@ -0,0 +1,101 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import pytest +import numpy as np + +from unittest.mock import Mock + +from model_compression_toolkit.core import MixedPrecisionQuantizationConfig +from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation +from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode + + +def custom_float_metric(model_mp) -> float: + return 100.0 + + +def custom_np_float_metric(model_mp) -> np.floating: + return np.float64(100.0) + + +def custom_str_metric(model_mp) -> str: + return 'test' + + +def custom_none_metric(model_mp): + return None + + +def get_sensitivity_evaluator(custom_metric_fn): + mock_graph = Mock() + mock_graph.get_topo_sorted_nodes.return_value = ['test', 'this', 'is', 'reset'] + mock_graph.get_outputs.return_value = [] + + def representative_data_gen() -> list: + for _ in range(5): + yield np.random.randn(2, 3, 248, 248) + + mock_fw_info = Mock() + + def custom_model_builder_return_value(*args, **kwargs): + mode = kwargs.get('mode') + if mode == ModelBuilderMode.FLOAT: + return (None, None) + else: + return (None, None, None) + + def custom_to_tensor(img): + return img + + mock_fw_impl = Mock() + mock_fw_impl.model_builder.side_effect = custom_model_builder_return_value + mock_fw_impl.to_tensor.side_effect = custom_to_tensor + + mock_set_layer_to_bitwidth = Mock() + + mp_cfg = MixedPrecisionQuantizationConfig(custom_metric_fn=custom_metric_fn) + + sensitivity_eval = SensitivityEvaluation(graph=mock_graph, + quant_config=mp_cfg, + representative_data_gen=representative_data_gen, + fw_info=mock_fw_info, + fw_impl=mock_fw_impl, + set_layer_to_bitwidth=mock_set_layer_to_bitwidth + ) + sensitivity_eval._configure_bitwidths_model = lambda *args, **kwargs: None # Method does nothing + sensitivity_eval.model_mp = Mock() + return sensitivity_eval + + +class TestMPCustomMetricFunction: + + 
@pytest.mark.parametrize("metric_fn, expected", [ + (custom_float_metric, 100.0), + (custom_np_float_metric, np.float64(100.0)), + ]) + def test_valid_metric_function(self, metric_fn, expected): + sensitivity_eval = get_sensitivity_evaluator(metric_fn) + assert len(sensitivity_eval.interest_points) == 0 + assert sensitivity_eval.compute_metric(Mock()) == expected + + @pytest.mark.parametrize("metric_fn, expected", [ + (custom_str_metric, str.__name__), + (custom_none_metric, type(None).__name__), + ]) + def test_type_invalid_metric_function(self, metric_fn, expected): + sensitivity_eval = get_sensitivity_evaluator(metric_fn) + assert len(sensitivity_eval.interest_points) == 0 + with pytest.raises(TypeError, match=f'The custom_metric_fn is expected to return float or numpy float, got {expected}'): + sensitivity_metric = sensitivity_eval.compute_metric(Mock())