@@ -89,13 +89,19 @@ def __init__(self,
8989 self .interest_points = get_mp_interest_points (graph ,
9090 fw_impl .count_node_for_mixed_precision_interest_points ,
9191 quant_config .num_interest_points_factor )
92+ # If using a custom metric - return only model outputs
93+ if self .quant_config .custom_metric_fn is not None :
94+ self .interest_points = []
9295
9396 # We use normalized MSE when not running hessian-based. For Hessian-based normalized MSE is not needed
9497 # because hessian weights already do normalization.
9598 use_normalized_mse = self .quant_config .use_hessian_based_scores is False
9699 self .ips_distance_fns , self .ips_axis = self ._init_metric_points_lists (self .interest_points , use_normalized_mse )
97100
98101 self .output_points = get_output_nodes_for_metric (graph )
102+ # If using a custom metric - return all model outputs
103+ if self .quant_config .custom_metric_fn is not None :
104+ self .output_points = [n .node for n in graph .get_outputs ()]
99105 self .out_ps_distance_fns , self .out_ps_axis = self ._init_metric_points_lists (self .output_points ,
100106 use_normalized_mse )
101107
@@ -187,6 +193,38 @@ def compute_metric(self,
187193 return self ._compute_mp_distance_measure (ipts_distances , out_pts_distances ,
188194 self .quant_config .distance_weighting_method )
189195
196+ def compute_custom_metric (self ,
197+ mp_model_configuration : List [int ],
198+ node_idx : List [int ] = None ,
199+ baseline_mp_configuration : List [int ] = None ) -> float :
200+ """
201+ Compute the sensitivity metric of the MP model for a given configuration (the sensitivity
202+ is computed on a custom function).
203+
204+ Args:
205+ mp_model_configuration: Bitwidth configuration to use to configure the MP model.
206+ node_idx: A list of nodes' indices to configure (instead of using the entire mp_model_configuration).
207+ baseline_mp_configuration: A mixed-precision configuration to set the model back to after modifying it to
208+ compute the metric for the given configuration.
209+
210+ Returns:
211+ The sensitivity metric of the MP model for a given configuration.
212+ """
213+
214+ # Configure MP model with the given configuration.
215+ self ._configure_bitwidths_model (mp_model_configuration ,
216+ node_idx )
217+
218+ # Compute the distance metric
219+ sensitivity_metric = self .quant_config .custom_metric_fn (self .model_mp )
220+
221+ # Configure MP model back to the same configuration as the baseline model if baseline provided
222+ if baseline_mp_configuration is not None :
223+ self ._configure_bitwidths_model (baseline_mp_configuration ,
224+ node_idx )
225+
226+ return sensitivity_metric
227+
190228 def _init_baseline_tensors_list (self ):
191229 """
192230 Evaluates the baseline model on all images and returns the obtained lists of tensors in a list for later use.
0 commit comments