# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15+ import contextlib
1516import copy
1617import itertools
1718
1819import numpy as np
19- from typing import Callable , Any , List , Tuple
20+ from typing import Callable , Any , List , Tuple , Dict , Optional
2021
2122from model_compression_toolkit .core import FrameworkInfo , MixedPrecisionQuantizationConfig
2223from model_compression_toolkit .core .common import Graph , BaseNode
@@ -156,7 +157,7 @@ def _init_metric_points_lists(self, points: List[BaseNode], norm_mse: bool = Fal
156157 axis_list .append (axis if distance_fn == compute_kl_divergence else None )
157158 return distance_fns_list , axis_list
158159
def compute_metric(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]) -> float:
    """
    Compute the sensitivity metric of the MP model for a given configuration (the sensitivity
    is computed based on the similarity of the interest points' outputs between the MP model
    and the float model).

    Args:
        mp_a_cfg: Nodes bitwidth indices to configure activation quantizers to (None = float).
        mp_w_cfg: Nodes bitwidth indices to configure weights quantizers to (None = float).

    Returns:
        The sensitivity metric of the MP model for the given configuration.
    """
    # Apply the requested bit-width configuration only for the duration of the metric
    # computation; the context manager restores the model to un-quantized on exit.
    with self._configured_mp_model(mp_a_cfg, mp_w_cfg):
        sensitivity_metric = self._compute_metric()

    return sensitivity_metric
179+
180+ def _compute_metric (self ):
181+ """ Compute sensitivity metric on a configured mp model. """
182+ if self .quant_config .custom_metric_fn :
183183 sensitivity_metric = self .quant_config .custom_metric_fn (self .model_mp )
184184 if not isinstance (sensitivity_metric , (float , np .floating )):
185- raise TypeError (f'The custom_metric_fn is expected to return float or numpy float, got { type (sensitivity_metric ).__name__ } ' )
186-
187- # restore configured nodes back to float
188- self ._configure_bitwidths_model ({n : None for n in mp_a_cfg }, {n : None for n in mp_w_cfg })
185+ raise TypeError (
186+ f'The custom_metric_fn is expected to return float or numpy float, got { type (sensitivity_metric ).__name__ } ' )
187+ return sensitivity_metric
188+
189+ # compute default metric
190+ ipts_distances , out_pts_distances = self ._compute_distance ()
191+ sensitivity_metric = self ._compute_mp_distance_measure (ipts_distances , out_pts_distances ,
192+ self .quant_config .distance_weighting_method )
189193 return sensitivity_metric
190194
191195 def _init_baseline_tensors_list (self ):
@@ -206,7 +210,8 @@ def _build_models(self) -> Any:
206210
207211 evaluation_graph = copy .deepcopy (self .graph )
208212
209- # Disable quantization for non-configurable nodes, and, if requested, for all activations.
213+ # Disable quantization for non-configurable nodes, and, if requested, for all activations (quantizers won't
214+ # be added to the model).
210215 for n in evaluation_graph .get_topo_sorted_nodes ():
211216 if self .disable_activation_for_metric or not n .has_configurable_activation ():
212217 for c in n .candidates_quantization_cfg :
@@ -261,35 +266,46 @@ def _compute_hessian_based_scores(self) -> np.ndarray:
261266 # Return the mean approximation value across all images for each interest point
262267 return np .mean (approx_by_image , axis = 0 )
263268
@contextlib.contextmanager
def _configured_mp_model(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]):
    """
    Context manager to configure specific configurable layers of the mp model. At exit,
    configuration is automatically restored to un-quantized.

    Args:
        mp_a_cfg: Nodes bitwidth indices to configure activation quantizers to.
        mp_w_cfg: Nodes bitwidth indices to configure weights quantizers to.

    Raises:
        ValueError: if the requested configuration is empty / all-None, if a requested node has
          no matching layers in the mp model, or if part of the configuration was not consumed.
    """
    if not (mp_a_cfg and any(v is not None for v in mp_a_cfg.values()) or
            mp_w_cfg and any(v is not None for v in mp_w_cfg.values())):
        raise ValueError('Requested configuration is either empty or contain only None values.')

    # Defined as a closure (not a method) so that it cannot be invoked directly, bypassing
    # the automatic restoration below.
    def configure(a_cfg, w_cfg):
        # Work on copies and consume them destructively to verify every entry was applied.
        a_cfg, w_cfg = dict(a_cfg), dict(w_cfg)
        for n in set(a_cfg).union(w_cfg):
            node_quant_layers = self.conf_node2layers.get(n)
            if node_quant_layers is None:  # pragma: no cover
                raise ValueError(f"Matching layers for node {n} not found in the mixed precision model configuration.")
            for qlayer in node_quant_layers:
                assert isinstance(qlayer, (self.fw_impl.activation_quant_layer_cls,
                                           self.fw_impl.weights_quant_layer_cls)), f'Unexpected {type(qlayer)} of node {n}'
                if isinstance(qlayer, self.fw_impl.activation_quant_layer_cls) and n in a_cfg:
                    set_activation_quant_layer_to_bitwidth(qlayer, a_cfg[n], self.fw_impl)
                    a_cfg.pop(n)
                elif isinstance(qlayer, self.fw_impl.weights_quant_layer_cls) and n in w_cfg:
                    set_weights_quant_layer_to_bitwidth(qlayer, w_cfg[n], self.fw_impl)
                    w_cfg.pop(n)
        if a_cfg or w_cfg:
            raise ValueError(f'Not all mp configs were consumed, remaining activation config {a_cfg}, '
                             f'weights config {w_cfg}.')

    try:
        # Configure inside the try so that a failure partway through (e.g. an unknown node
        # name raised after some layers were already configured) still triggers restoration.
        configure(mp_a_cfg, mp_w_cfg)
        yield
    finally:
        # Restore only nodes that have matching layers, so a missing-node error from the
        # configure step above is not raised a second time (masking the original) here.
        configure({n: None for n in mp_a_cfg if n in self.conf_node2layers},
                  {n: None for n in mp_w_cfg if n in self.conf_node2layers})
293309
294310 def _compute_points_distance (self ,
295311 baseline_tensors : List [Any ],
0 commit comments