Skip to content

Commit eb7d615

Browse files
committed
add support for custom metric function for mixed precision
1 parent 018b352 commit eb7d615

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class MixedPrecisionQuantizationConfig:
3939

4040
compute_distance_fn: Optional[Callable] = None
4141
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
42+
custom_metric_fn: Optional[Callable] = None
4243
num_of_images: int = MP_DEFAULT_NUM_SAMPLES
4344
configuration_overwrite: Optional[List[int]] = None
4445
num_interest_points_factor: float = field(default=1.0, metadata={"description": "Should be between 0.0 and 1.0"})

model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,15 @@ def topo_cfg(cfg: dict) -> list:
162162
return topo_cfg
163163

164164
def compute_metric(cfg, node_idx=None, baseline_cfg=None):
165-
return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
166-
node_idx,
167-
topo_cfg(baseline_cfg) if baseline_cfg else None)
165+
if self.sensitivity_evaluator.quant_config.custom_metric_fn is None:
166+
return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
167+
node_idx,
168+
topo_cfg(baseline_cfg) if baseline_cfg else None)
169+
else:
170+
return self.sensitivity_evaluator.compute_custom_metric(topo_cfg(cfg),
171+
node_idx,
172+
topo_cfg(baseline_cfg) if baseline_cfg else None)
173+
168174
if self.using_virtual_graph:
169175
origin_max_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(
170176
self.max_ru_config)

model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,19 @@ def __init__(self,
8989
self.interest_points = get_mp_interest_points(graph,
9090
fw_impl.count_node_for_mixed_precision_interest_points,
9191
quant_config.num_interest_points_factor)
92+
# If using a custom metric - no interest points are used, only the model outputs
93+
if self.quant_config.custom_metric_fn is not None:
94+
self.interest_points = []
9295

9396
# We use normalized MSE when not running hessian-based. For Hessian-based normalized MSE is not needed
9497
# because hessian weights already do normalization.
9598
use_normalized_mse = self.quant_config.use_hessian_based_scores is False
9699
self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points, use_normalized_mse)
97100

98101
self.output_points = get_output_nodes_for_metric(graph)
102+
# If using a custom metric - use all of the graph's output nodes as output points
103+
if self.quant_config.custom_metric_fn is not None:
104+
self.output_points = [n.node for n in graph.get_outputs()]
99105
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points,
100106
use_normalized_mse)
101107

@@ -187,6 +193,38 @@ def compute_metric(self,
187193
return self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
188194
self.quant_config.distance_weighting_method)
189195

196+
def compute_custom_metric(self,
                          mp_model_configuration: List[int],
                          node_idx: List[int] = None,
                          baseline_mp_configuration: List[int] = None) -> float:
    """
    Compute the sensitivity metric of the MP model for a given configuration, using the
    user-supplied custom metric function (quant_config.custom_metric_fn) that is called
    on the configured MP model.

    Args:
        mp_model_configuration: Bitwidth configuration to use to configure the MP model.
        node_idx: A list of nodes' indices to configure (instead of using the entire mp_model_configuration).
        baseline_mp_configuration: A mixed-precision configuration to set the model back to after modifying it to
            compute the metric for the given configuration.

    Returns:
        The value returned by the custom metric function for the configured MP model.

    Raises:
        Any exception raised by the custom metric function is propagated to the caller. If a
        baseline configuration was provided, the MP model is restored to it before propagating.
    """

    # Configure MP model with the given configuration.
    self._configure_bitwidths_model(mp_model_configuration,
                                    node_idx)

    try:
        # Compute the custom metric on the configured MP model.
        return self.quant_config.custom_metric_fn(self.model_mp)
    finally:
        # Configure MP model back to the same configuration as the baseline model if baseline provided.
        # Done in a 'finally' clause so a failing user-supplied metric function does not leave the
        # MP model in the modified configuration.
        if baseline_mp_configuration is not None:
            self._configure_bitwidths_model(baseline_mp_configuration,
                                            node_idx)
227+
190228
def _init_baseline_tensors_list(self):
191229
"""
192230
Evaluates the baseline model on all images and returns the obtained lists of tensors in a list for later use.

0 commit comments

Comments
 (0)