Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class MixedPrecisionQuantizationConfig:
Args:
compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
custom_metric_fn (Callable): Function to compute a custom metric. It receives the mixed-precision model (model_mp) as input and returns a float value for the metric. If None, the interest point metric is used.
num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
Expand All @@ -39,6 +40,7 @@ class MixedPrecisionQuantizationConfig:

compute_distance_fn: Optional[Callable] = None
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
custom_metric_fn: Optional[Callable] = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add description to docstring, including the expected api of the function (args it accepts and what it should return).

num_of_images: int = MP_DEFAULT_NUM_SAMPLES
configuration_overwrite: Optional[List[int]] = None
num_interest_points_factor: float = field(default=1.0, metadata={"description": "Should be between 0.0 and 1.0"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def compute_metric(cfg, node_idx=None, baseline_cfg=None):
return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
node_idx,
topo_cfg(baseline_cfg) if baseline_cfg else None)

if self.using_virtual_graph:
origin_max_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(
self.max_ru_config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,19 @@ def __init__(self,
self.interest_points = get_mp_interest_points(graph,
fw_impl.count_node_for_mixed_precision_interest_points,
quant_config.num_interest_points_factor)
# If using a custom metric - return only model outputs
if self.quant_config.custom_metric_fn is not None:
self.interest_points = []

# We use normalized MSE when not running hessian-based. For Hessian-based normalized MSE is not needed
# because hessian weights already do normalization.
use_normalized_mse = self.quant_config.use_hessian_based_scores is False
self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points, use_normalized_mse)

self.output_points = get_output_nodes_for_metric(graph)
# If using a custom metric - return all model outputs
if self.quant_config.custom_metric_fn is not None:
self.output_points = [n.node for n in graph.get_outputs()]
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points,
use_normalized_mse)

Expand Down Expand Up @@ -160,7 +166,7 @@ def compute_metric(self,
"""
Compute the sensitivity metric of the MP model for a given configuration (the sensitivity
is computed based on the similarity of the interest points' outputs between the MP model
and the float model).
and the float model or a custom metric if given).

Args:
mp_model_configuration: Bitwidth configuration to use to configure the MP model.
Expand All @@ -177,15 +183,21 @@ def compute_metric(self,
node_idx)

# Compute the distance metric
ipts_distances, out_pts_distances = self._compute_distance()
if self.quant_config.custom_metric_fn is None:
ipts_distances, out_pts_distances = self._compute_distance()
sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
self.quant_config.distance_weighting_method)
else:
sensitivity_metric = self.quant_config.custom_metric_fn(self.model_mp)
if not isinstance(sensitivity_metric, (float, np.floating)):
raise TypeError(f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')

# Configure MP model back to the same configuration as the baseline model if baseline provided
if baseline_mp_configuration is not None:
self._configure_bitwidths_model(baseline_mp_configuration,
node_idx)

return self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
self.quant_config.distance_weighting_method)
return sensitivity_metric

def _init_baseline_tensors_list(self):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import numpy as np

from unittest.mock import Mock

from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode


def custom_float_metric(model_mp) -> float:
    """Valid metric stub: ignores the model and returns a plain Python float."""
    return float(100)


def custom_np_float_metric(model_mp) -> np.floating:
    """Valid metric stub: ignores the model and returns a numpy float64."""
    metric_value = np.float64(100.0)
    return metric_value


def custom_str_metric(model_mp) -> str:
    """Invalid metric stub: returns a string instead of a float (used to trigger TypeError)."""
    bad_value = 'test'
    return bad_value


def custom_none_metric(model_mp):
    """Invalid metric stub: yields None (used to trigger TypeError)."""
    return


def get_sensitivity_evaluator(custom_metric_fn):
    """Build a SensitivityEvaluation wired entirely with mocks, configured
    to use the given custom metric function.

    The graph, framework info/impl and bitwidth setter are all Mock objects;
    only the pieces compute_metric touches are given concrete behavior.
    """
    graph_mock = Mock()
    graph_mock.get_topo_sorted_nodes.return_value = ['test', 'this', 'is', 'reset']
    graph_mock.get_outputs.return_value = []

    def representative_data_gen() -> list:
        # Five random image-like batches; shapes only need to be consistent.
        for _ in range(5):
            yield np.random.randn(2, 3, 248, 248)

    def fake_model_builder(*args, **kwargs):
        # FLOAT mode builds a 2-tuple result; any other mode a 3-tuple.
        if kwargs.get('mode') == ModelBuilderMode.FLOAT:
            return None, None
        return None, None, None

    fw_impl_mock = Mock()
    fw_impl_mock.model_builder.side_effect = fake_model_builder
    fw_impl_mock.to_tensor.side_effect = lambda img: img  # identity conversion

    quant_cfg = MixedPrecisionQuantizationConfig(custom_metric_fn=custom_metric_fn)

    evaluator = SensitivityEvaluation(graph=graph_mock,
                                      quant_config=quant_cfg,
                                      representative_data_gen=representative_data_gen,
                                      fw_info=Mock(),
                                      fw_impl=fw_impl_mock,
                                      set_layer_to_bitwidth=Mock()
                                      )
    # No-op the bitwidth configuration step and stub the MP model itself.
    evaluator._configure_bitwidths_model = lambda *args, **kwargs: None
    evaluator.model_mp = Mock()
    return evaluator


class TestMPCustomMetricFunction:
    """Tests for the custom_metric_fn hook of the mixed-precision sensitivity evaluation.

    When a custom metric is configured, interest points are expected to be empty
    and compute_metric must either return the metric value (float / numpy float)
    or raise TypeError for any other return type.
    """

    @pytest.mark.parametrize("metric_fn, expected", [
        (custom_float_metric, 100.0),
        (custom_np_float_metric, np.float64(100.0)),
    ])
    def test_valid_metric_function(self, metric_fn, expected):
        """A metric returning float or numpy float is passed through unchanged."""
        sensitivity_eval = get_sensitivity_evaluator(metric_fn)
        # Custom metric disables interest points entirely.
        assert len(sensitivity_eval.interest_points) == 0
        assert sensitivity_eval.compute_metric(Mock()) == expected

    @pytest.mark.parametrize("metric_fn, expected", [
        (custom_str_metric, str.__name__),
        (custom_none_metric, type(None).__name__),
    ])
    def test_type_invalid_metric_function(self, metric_fn, expected):
        """A metric returning a non-float raises TypeError naming the offending type."""
        sensitivity_eval = get_sensitivity_evaluator(metric_fn)
        assert len(sensitivity_eval.interest_points) == 0
        with pytest.raises(TypeError, match=f'The custom_metric_fn is expected to return float or numpy float, got {expected}'):
            # Fixed: result was previously assigned to an unused local variable.
            sensitivity_eval.compute_metric(Mock())