Skip to content

Commit b7fddbc

Browse files
irenab
authored and committed
add framework integration test for sensitivity
1 parent d1e0c47 commit b7fddbc

File tree

7 files changed

+487
-40
lines changed

7 files changed

+487
-40
lines changed

model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py

Lines changed: 53 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -12,11 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
import contextlib
1516
import copy
1617
import itertools
1718

1819
import numpy as np
19-
from typing import Callable, Any, List, Tuple
20+
from typing import Callable, Any, List, Tuple, Dict, Optional
2021

2122
from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
2223
from model_compression_toolkit.core.common import Graph, BaseNode
@@ -156,7 +157,7 @@ def _init_metric_points_lists(self, points: List[BaseNode], norm_mse: bool = Fal
156157
axis_list.append(axis if distance_fn == compute_kl_divergence else None)
157158
return distance_fns_list, axis_list
158159

159-
def compute_metric(self, mp_a_cfg: Dict[str, int], mp_w_cfg: Dict[str, int]) -> float:
160+
def compute_metric(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]) -> float:
160161
"""
161162
Compute the sensitivity metric of the MP model for a given configuration (the sensitivity
162163
is computed based on the similarity of the interest points' outputs between the MP model
@@ -171,21 +172,24 @@ def compute_metric(self, mp_a_cfg: Dict[str, int], mp_w_cfg: Dict[str, int]) ->
171172
The sensitivity metric of the MP model for a given configuration.
172173
"""
173174

174-
# Configure MP model with the given configuration.
175-
self._configure_bitwidths_model(mp_a_cfg, mp_w_cfg)
175+
with self._configured_mp_model(mp_a_cfg, mp_w_cfg):
176+
sensitivity_metric = self._compute_metric()
176177

177-
# Compute the distance metric
178-
if self.quant_config.custom_metric_fn is None:
179-
ipts_distances, out_pts_distances = self._compute_distance()
180-
sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
181-
self.quant_config.distance_weighting_method)
182-
else:
178+
return sensitivity_metric
179+
180+
def _compute_metric(self):
181+
""" Compute sensitivity metric on a configured mp model. """
182+
if self.quant_config.custom_metric_fn:
183183
sensitivity_metric = self.quant_config.custom_metric_fn(self.model_mp)
184184
if not isinstance(sensitivity_metric, (float, np.floating)):
185-
raise TypeError(f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')
186-
187-
# restore configured nodes back to float
188-
self._configure_bitwidths_model({n: None for n in mp_a_cfg}, {n: None for n in mp_w_cfg})
185+
raise TypeError(
186+
f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')
187+
return sensitivity_metric
188+
189+
# compute default metric
190+
ipts_distances, out_pts_distances = self._compute_distance()
191+
sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
192+
self.quant_config.distance_weighting_method)
189193
return sensitivity_metric
190194

191195
def _init_baseline_tensors_list(self):
@@ -206,7 +210,8 @@ def _build_models(self) -> Any:
206210

207211
evaluation_graph = copy.deepcopy(self.graph)
208212

209-
# Disable quantization for non-configurable nodes, and, if requested, for all activations.
213+
# Disable quantization for non-configurable nodes, and, if requested, for all activations (quantizers won't
214+
# be added to the model).
210215
for n in evaluation_graph.get_topo_sorted_nodes():
211216
if self.disable_activation_for_metric or not n.has_configurable_activation():
212217
for c in n.candidates_quantization_cfg:
@@ -261,35 +266,46 @@ def _compute_hessian_based_scores(self) -> np.ndarray:
261266
# Return the mean approximation value across all images for each interest point
262267
return np.mean(approx_by_image, axis=0)
263268

264-
def _configure_bitwidths_model(self, mp_a_cfg: Dict[str, int], mp_w_cfg: Dict[str, int]):
269+
@contextlib.contextmanager
270+
def _configured_mp_model(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]):
265271
"""
266-
Configure specific configurable layers of the mp model.
272+
Context manager to configure specific configurable layers of the mp model. At exit, configuration is
273+
automatically restored to un-quantized.
267274
268275
Args:
269276
mp_a_cfg: Nodes bitwidth indices to configure activation quantizers to.
270277
mp_w_cfg: Nodes bitwidth indices to configure weights quantizers to.
271278
272279
"""
273-
node_names = set(mp_a_cfg.keys()).union(set(mp_w_cfg.keys()))
274-
mp_a_cfg = copy.deepcopy(mp_a_cfg)
275-
mp_w_cfg = copy.deepcopy(mp_w_cfg)
276-
277-
for n in node_names:
278-
node_quant_layers = self.conf_node2layers.get(n)
279-
if node_quant_layers is None: # pragma: no cover
280-
raise ValueError(f"Matching layers for node {n} not found in the mixed precision model configuration.")
281-
for qlayer in node_quant_layers:
282-
assert isinstance(qlayer, (self.fw_impl.activation_quant_layer_cls,
283-
self.fw_impl.weights_quant_layer_cls)), f'{type(qlayer)} of node {n}'
284-
if isinstance(qlayer, self.fw_impl.activation_quant_layer_cls) and n in mp_a_cfg:
285-
set_activation_quant_layer_to_bitwidth(qlayer, mp_a_cfg[n], self.fw_impl)
286-
mp_a_cfg.pop(n)
287-
elif isinstance(qlayer, self.fw_impl.weights_quant_layer_cls) and n in mp_w_cfg:
288-
set_weights_quant_layer_to_bitwidth(qlayer, mp_w_cfg[n], self.fw_impl)
289-
mp_w_cfg.pop(n)
290-
if mp_a_cfg or mp_w_cfg:
291-
raise ValueError(f'Not all mp configs were consumed, remaining activation config {mp_a_cfg}, '
292-
f'weights config {mp_w_cfg}')
280+
if not (mp_a_cfg and any(v is not None for v in mp_a_cfg.values()) or
281+
mp_w_cfg and any(v is not None for v in mp_w_cfg.values())):
282+
raise ValueError(f'Requested configuration is either empty or contain only None values.')
283+
284+
# defined here so that it can't be used directly
285+
def configure(a_cfg, w_cfg):
286+
node_names = set(a_cfg.keys()).union(set(w_cfg.keys()))
287+
for n in node_names:
288+
node_quant_layers = self.conf_node2layers.get(n)
289+
if node_quant_layers is None: # pragma: no cover
290+
raise ValueError(f"Matching layers for node {n} not found in the mixed precision model configuration.")
291+
for qlayer in node_quant_layers:
292+
assert isinstance(qlayer, (self.fw_impl.activation_quant_layer_cls,
293+
self.fw_impl.weights_quant_layer_cls)), f'Unexpected {type(qlayer)} of node {n}'
294+
if isinstance(qlayer, self.fw_impl.activation_quant_layer_cls) and n in a_cfg:
295+
set_activation_quant_layer_to_bitwidth(qlayer, a_cfg[n], self.fw_impl)
296+
a_cfg.pop(n)
297+
elif isinstance(qlayer, self.fw_impl.weights_quant_layer_cls) and n in w_cfg:
298+
set_weights_quant_layer_to_bitwidth(qlayer, w_cfg[n], self.fw_impl)
299+
w_cfg.pop(n)
300+
if a_cfg or w_cfg:
301+
raise ValueError(f'Not all mp configs were consumed, remaining activation config {a_cfg}, '
302+
f'weights config {w_cfg}.')
303+
304+
configure(mp_a_cfg.copy(), mp_w_cfg.copy())
305+
try:
306+
yield
307+
finally:
308+
configure({n: None for n in mp_a_cfg}, {n: None for n in mp_w_cfg})
293309

294310
def _compute_points_distance(self,
295311
baseline_tensors: List[Any],

0 commit comments

Comments
 (0)