Skip to content

Commit 09ed05c

Browse files
authored
Use float mp model baseline instead of maxbit configuration (#1441)
* float mp model for sensitivity - disable all quantizers except current candidate, instead of using maxbit config * ensure non-negative value in similarity analyzer (due to float precision) * add normalization method and epsilon for sensitivity metric to MP config * print only changed by refinement solutions instead of all * remove support for non-configurable quant layer in mp back2framework
1 parent 77e3d11 commit 09ed05c

File tree

17 files changed

+954
-368
lines changed

17 files changed

+954
-368
lines changed

model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,23 @@
1414
# ==============================================================================
1515

1616
from dataclasses import dataclass, field
17+
from enum import Enum
1718
from typing import List, Callable, Optional
1819
from model_compression_toolkit.constants import MP_DEFAULT_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
1920
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
2021

2122

23+
class MpMetricNormalization(Enum):
24+
"""
25+
MAXBIT: normalize sensitivity metrics of layer candidates by max-bitwidth candidate (of that layer).
26+
MINBIT: normalize sensitivity metrics of layer candidates by min-bitwidth candidate (of that layer).
27+
NONE: no normalization.
28+
"""
29+
MAXBIT = 'MAXBIT'
30+
MINBIT = 'MINBIT'
31+
NONE = 'NONE'
32+
33+
2234
@dataclass
2335
class MixedPrecisionQuantizationConfig:
2436
"""
@@ -27,7 +39,6 @@ class MixedPrecisionQuantizationConfig:
2739
Args:
2840
compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
2941
distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
30-
custom_metric_fn (Callable): Function to compute a custom metric. As input gets the model_mp and returns a float value for metric. If None, uses interest point metric.
3142
num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
3243
configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
3344
num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
@@ -36,11 +47,16 @@ class MixedPrecisionQuantizationConfig:
3647
refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
3748
metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values, In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
3849
hessian_batch_size (int): The Hessian computation batch size. used only if using mixed precision with Hessian-based objective.
39-
"""
50+
metric_normalization (MpMetricNormalization): Metric normalization method.
51+
metric_epsilon (float | None): ensure minimal distance between the metric for any non-max-bitwidth candidate
52+
and a max-bitwidth candidate, i.e. metric(non-max-bitwidth) >= metric(max-bitwidth) + epsilon.
53+
If None, the computed metrics are used as-is.
54+
custom_metric_fn (Callable): Function to compute a custom metric. As input gets the model_mp and returns a
55+
float value for metric. If None, uses interest point metric.
4056
57+
"""
4158
compute_distance_fn: Optional[Callable] = None
4259
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
43-
custom_metric_fn: Optional[Callable] = None
4460
num_of_images: int = MP_DEFAULT_NUM_SAMPLES
4561
configuration_overwrite: Optional[List[int]] = None
4662
num_interest_points_factor: float = field(default=1.0, metadata={"description": "Should be between 0.0 and 1.0"})
@@ -49,6 +65,9 @@ class MixedPrecisionQuantizationConfig:
4965
refine_mp_solution: bool = True
5066
metric_normalization_threshold: float = 1e10
5167
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
68+
metric_normalization: MpMetricNormalization = MpMetricNormalization.NONE
69+
metric_epsilon: Optional[float] = 1e-6
70+
custom_metric_fn: Optional[Callable] = None
5271
_is_mixed_precision_enabled: bool = field(init=False, default=False)
5372

5473
def __post_init__(self):

model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,11 @@ def search_bit_width(graph: Graph,
9898

9999
# Search manager and LP are highly coupled, so LP search method was moved inside search manager.
100100
search_manager = MixedPrecisionSearchManager(graph,
101-
fw_info,
102-
fw_impl,
103-
se,
104-
target_resource_utilization)
101+
fw_info=fw_info,
102+
fw_impl=fw_impl,
103+
sensitivity_evaluator=se,
104+
target_resource_utilization=target_resource_utilization,
105+
mp_config=mp_config)
105106
nodes_bit_cfg = search_manager.search()
106107

107108
graph.skip_validation_check = False

model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py

Lines changed: 69 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
import os
16+
1517
import itertools
1618

1719
import copy
1820
from collections import defaultdict
1921

2022
from tqdm import tqdm
2123

22-
from typing import Dict, List, Tuple, Optional
24+
from typing import Dict, List, Tuple, Optional, Set
2325

2426
import numpy as np
2527

@@ -40,6 +42,8 @@
4042
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
4143
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
4244
from model_compression_toolkit.logger import Logger
45+
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
46+
MixedPrecisionQuantizationConfig, MpMetricNormalization
4347

4448

4549
class MixedPrecisionSearchManager:
@@ -52,7 +56,8 @@ def __init__(self,
5256
fw_info: FrameworkInfo,
5357
fw_impl: FrameworkImplementation,
5458
sensitivity_evaluator: SensitivityEvaluation,
55-
target_resource_utilization: ResourceUtilization):
59+
target_resource_utilization: ResourceUtilization,
60+
mp_config: MixedPrecisionQuantizationConfig):
5661
"""
5762
5863
Args:
@@ -74,21 +79,21 @@ def __init__(self,
7479

7580
self.sensitivity_evaluator = sensitivity_evaluator
7681
self.target_resource_utilization = target_resource_utilization
82+
self.mp_config = mp_config
7783

7884
self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info)
7985

8086
self.ru_targets = target_resource_utilization.get_restricted_targets()
81-
self.ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl)
87+
self.orig_graph_ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl)
8288

8389
self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
84-
self.max_ru_config: Dict[BaseNode, int] = self.mp_graph.get_max_candidates_config(fw_info)
8590

86-
self.config_reconstruction_helper = ConfigReconstructionHelper(self.original_graph)
91+
self.config_reconstructor = None
92+
orig_min_config = self.min_ru_config
8793
if self.using_virtual_graph:
88-
real_min_ru_config = self.config_reconstruction_helper.reconstruct_full_configuration(self.min_ru_config)
89-
self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, real_min_ru_config)
90-
else:
91-
self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config)
94+
self.config_reconstructor = ConfigReconstructionHelper(self.original_graph)
95+
orig_min_config = self.config_reconstructor.reconstruct_full_configuration(self.min_ru_config)
96+
self.min_ru = self.orig_graph_ru_helper.compute_utilization(self.ru_targets, orig_min_config)
9297

9398
def search(self) -> Dict[BaseNode, int]:
9499
"""
@@ -100,7 +105,7 @@ def search(self) -> Dict[BaseNode, int]:
100105
mp_config = self._prepare_and_run_solver()
101106

102107
if self.using_virtual_graph:
103-
mp_config = self.config_reconstruction_helper.reconstruct_full_configuration(mp_config)
108+
mp_config = self.config_reconstructor.reconstruct_full_configuration(mp_config)
104109

105110
return mp_config
106111

@@ -143,61 +148,64 @@ def _get_relative_ru_constraint_per_mem_element(self) -> Dict[RUTarget, np.ndarr
143148
f"following targets: {unsatisfiable_targets}")
144149
return rel_target_ru
145150

146-
def _build_sensitivity_mapping(self, eps: float = 1e-6) -> Dict[BaseNode, List[float]]:
151+
def _build_sensitivity_mapping(self) -> Dict[BaseNode, List[float]]:
147152
"""
148153
This function measures the sensitivity of a change in a bitwidth of a layer on the entire model.
149154
150-
Args:
151-
eps: if sensitivity for a non-max candidate is lower than for a max candidate, we set it to
152-
sensitivity of a max candidate + epsilon.
153-
154155
Returns:
155156
Mapping from nodes to their bitwidth candidates sensitivity.
156157
"""
157-
158158
Logger.info('Starting to evaluate metrics')
159-
160-
orig_sorted_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
161-
162-
def topo_cfg(cfg: dict) -> list:
163-
topo_cfg = [cfg[n] for n in orig_sorted_nodes]
164-
assert len(topo_cfg) == len(cfg)
165-
return topo_cfg
166-
167-
def compute_metric(cfg, node_idx=None, baseline_cfg=None):
168-
return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
169-
node_idx,
170-
topo_cfg(baseline_cfg) if baseline_cfg else None)
171-
172-
if self.using_virtual_graph:
173-
origin_max_config = self.config_reconstruction_helper.reconstruct_full_configuration(self.max_ru_config)
174-
max_config_value = compute_metric(origin_max_config)
175-
else:
176-
max_config_value = compute_metric(self.max_ru_config)
159+
norm_method = self.mp_config.metric_normalization
160+
eps = self.mp_config.metric_epsilon
161+
162+
verbose = 'VERBOSE_MP_METRIC' in os.environ
163+
164+
def normalize(node_candidates_metrics, max_ind):
165+
if norm_method == MpMetricNormalization.NONE:
166+
return node_candidates_metrics
167+
if norm_method == MpMetricNormalization.MAXBIT:
168+
ref_ind = max_ind
169+
elif norm_method == MpMetricNormalization.MINBIT:
170+
ref_ind = node.find_min_candidate_index()
171+
else: # pragma: no cover
172+
raise ValueError(f'Unexpected MpMetricNormalization mode {norm_method}')
173+
normalized_metrics = node_candidates_metrics / node_candidates_metrics[ref_ind]
174+
if verbose and not np.array_equal(normalized_metrics, node_candidates_metrics):
175+
print(f'{"normalized metric:":25}', candidates_sensitivity)
176+
return normalized_metrics
177+
178+
def ensure_maxbit_minimal_metric(node_candidates_metrics, max_ind):
179+
if eps is None:
180+
return node_candidates_metrics
181+
# We want maxbit configuration to have the minimal distance metric (so that optimization objective
182+
# doesn't prefer lower bits). If we got a smaller metric for non-maxbit, we update it to metric(maxbit)+eps.
183+
max_val = node_candidates_metrics[max_ind]
184+
metrics = np.maximum(node_candidates_metrics, max_val + eps)
185+
metrics[max_ind] = max_val
186+
if verbose and not np.array_equal(metrics, node_candidates_metrics):
187+
print(f'{"eps-adjusted metric:":25}', candidates_sensitivity)
188+
return metrics
177189

178190
layer_to_metrics_mapping = defaultdict(list)
179191
for node_idx, node in tqdm(enumerate(self.mp_topo_configurable_nodes)):
192+
candidates_sensitivity = np.empty(len(node.candidates_quantization_cfg))
180193
for bitwidth_idx, _ in enumerate(node.candidates_quantization_cfg):
181-
if self.max_ru_config[node] == bitwidth_idx:
182-
# This is a computation of the metric for the max configuration, assign pre-calculated value
183-
layer_to_metrics_mapping[node].append(max_config_value)
184-
continue
185-
186-
# Create a configuration that differs at one layer only from the baseline model
187-
mp_model_configuration = self.max_ru_config.copy()
188-
mp_model_configuration[node] = bitwidth_idx
189-
190-
# Build a distance matrix using the function we got from the framework implementation.
191194
if self.using_virtual_graph:
192-
# Reconstructing original graph's configuration from virtual graph's configuration
193-
orig_mp_config = self.config_reconstruction_helper.reconstruct_full_configuration(mp_model_configuration)
194-
changed_nodes = [orig_sorted_nodes.index(n) for n, ind in orig_mp_config.items()
195-
if origin_max_config[n] != ind]
196-
metric_value = compute_metric(orig_mp_config, changed_nodes, origin_max_config)
195+
a_cfg, w_cfg = self.config_reconstructor.reconstruct_separate_aw_configs({node: bitwidth_idx})
197196
else:
198-
metric_value = compute_metric(mp_model_configuration, [node_idx], self.max_ru_config)
199-
metric_value = max(metric_value, max_config_value + eps)
200-
layer_to_metrics_mapping[node].append(metric_value)
197+
a_cfg = {node: bitwidth_idx} if node.has_configurable_activation() else {}
198+
w_cfg = {node: bitwidth_idx} if node.has_any_configurable_weight() else {}
199+
candidates_sensitivity[bitwidth_idx] = self.sensitivity_evaluator.compute_metric(
200+
mp_a_cfg={n.name: ind for n, ind in a_cfg.items()},
201+
mp_w_cfg={n.name: ind for n, ind in w_cfg.items()}
202+
)
203+
if verbose:
204+
print(f'{node.name}\n{"raw metric:":25}', candidates_sensitivity)
205+
max_ind = node.find_max_candidate_index()
206+
candidates_sensitivity = normalize(candidates_sensitivity, max_ind)
207+
candidates_sensitivity = ensure_maxbit_minimal_metric(candidates_sensitivity, max_ind)
208+
layer_to_metrics_mapping[node] = candidates_sensitivity
201209

202210
# Finalize distance metric mapping
203211
self._finalize_distance_metric(layer_to_metrics_mapping)
@@ -244,8 +252,9 @@ def _compute_relative_ru_matrices(self) -> Dict[RUTarget, np.ndarray]:
244252
else:
245253
cfg = self.min_ru_config.copy()
246254
cfg[node] = candidate_idx
247-
real_cfg = self.config_reconstruction_helper.reconstruct_full_configuration(cfg)
248-
candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, real_cfg)
255+
if self.using_virtual_graph:
256+
cfg = self.config_reconstructor.reconstruct_full_configuration(cfg)
257+
candidate_rus = self.orig_graph_ru_helper.compute_utilization(self.ru_targets, cfg)
249258

250259
for target, ru in candidate_rus.items():
251260
rus_per_candidate[target].append(ru)
@@ -283,8 +292,8 @@ def compute_resource_utilization_for_config(self, config: Dict[BaseNode, int]) -
283292
with the given config.
284293
285294
"""
286-
act_qcs, w_qcs = self.ru_helper.get_quantization_candidates(config)
287-
ru = self.ru_helper.ru_calculator.compute_resource_utilization(
295+
act_qcs, w_qcs = self.orig_graph_ru_helper.get_quantization_candidates(config)
296+
ru = self.orig_graph_ru_helper.ru_calculator.compute_resource_utilization(
288297
target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs,
289298
w_qcs=w_qcs, ru_targets=self.ru_targets, allow_unused_qcs=True)
290299
return ru
@@ -303,7 +312,7 @@ def _finalize_distance_metric(self, layer_to_metrics_mapping: Dict[BaseNode, Lis
303312
# normalize metric for numerical stability
304313
max_dist = max(itertools.chain.from_iterable(layer_to_metrics_mapping.values()))
305314

306-
if max_dist >= self.sensitivity_evaluator.quant_config.metric_normalization_threshold:
315+
if max_dist >= self.mp_config.metric_normalization_threshold:
307316
Logger.warning(f"The mixed precision distance metric values indicate a large error in the quantized model."
308317
f"this can cause numerical issues."
309318
f"The program will proceed with mixed precision search after scaling the metric values,"
@@ -387,7 +396,9 @@ def reconstruct_full_configuration(self,
387396

388397
return orig_cfg
389398

390-
def reconstruct_separate_aw_configs(self, virtual_cfg: Dict[BaseNode, int], include_non_configurable: bool) \
399+
def reconstruct_separate_aw_configs(self,
400+
virtual_cfg: Dict[BaseNode, int],
401+
include_non_configurable: bool = False) \
391402
-> Tuple[Dict[BaseNode, int], Dict[BaseNode, int]]:
392403
"""
393404
Retrieves original activation and weights nodes and corresponding candidates for a given configuration of the

0 commit comments

Comments
 (0)