Skip to content

Commit 35ffd84

Browse files
committed
Merge branch 'main' into apply_quant_info_to_fusinginfo
2 parents 629f837 + 09ed05c commit 35ffd84

File tree

33 files changed

+1247
-867
lines changed

33 files changed

+1247
-867
lines changed

model_compression_toolkit/core/common/framework_implementation.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,31 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
from abc import ABC, abstractmethod
16-
from typing import Callable, Any, List, Tuple, Dict, Generator
16+
from typing import Callable, Any, List, Tuple, Generator, Type
1717

1818
import numpy as np
1919

2020
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
21-
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
2221
from model_compression_toolkit.core import common
2322
from model_compression_toolkit.core.common import BaseNode
24-
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
2523
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
2624
from model_compression_toolkit.core.common.graph.base_graph import Graph
27-
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianInfoService
28-
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
25+
from model_compression_toolkit.core.common.hessian import HessianScoresRequest
2926
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
3027
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
3128
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
3229
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
33-
from model_compression_toolkit.core.common.user_info import UserInformation
3430

3531

3632
class FrameworkImplementation(ABC):
3733
"""
3834
An abstract class with abstract methods that should be implemented when supporting a new
3935
framework in MCT.
4036
"""
37+
weights_quant_layer_cls: Type
38+
activation_quant_layer_cls: Type
39+
configurable_weights_quantizer_cls: Type
40+
configurable_activation_quantizer_cls: Type
4141

4242
@property
4343
def constants(self):
@@ -327,33 +327,6 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
327327
f'framework\'s get_substitutions_after_second_moment_correction '
328328
f'method.') # pragma: no cover
329329

330-
@abstractmethod
331-
def get_sensitivity_evaluator(self,
332-
graph: Graph,
333-
quant_config: MixedPrecisionQuantizationConfig,
334-
representative_data_gen: Callable,
335-
fw_info: FrameworkInfo,
336-
hessian_info_service: HessianInfoService = None,
337-
disable_activation_for_metric: bool = False) -> SensitivityEvaluation:
338-
"""
339-
Creates and returns an object which handles the computation of a sensitivity metric for a mixed-precision
340-
configuration (comparing to the float model).
341-
342-
Args:
343-
graph: Graph to build its float and mixed-precision models.
344-
quant_config: QuantizationConfig of how the model should be quantized.
345-
representative_data_gen: Dataset to use for retrieving images for the models inputs.
346-
fw_info: FrameworkInfo object with information about the specific framework's model.
347-
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
348-
hessian_info_service: HessianInfoService to fetch information based on Hessian-approximation.
349-
350-
Returns:
351-
A function that computes the metric.
352-
"""
353-
354-
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
355-
f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover
356-
357330
def get_node_prior_info(self, node: BaseNode,
358331
fw_info: FrameworkInfo,
359332
graph: Graph) -> NodePriorInfo:

model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,23 @@
1414
# ==============================================================================
1515

1616
from dataclasses import dataclass, field
17+
from enum import Enum
1718
from typing import List, Callable, Optional
1819
from model_compression_toolkit.constants import MP_DEFAULT_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
1920
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
2021

2122

23+
class MpMetricNormalization(Enum):
24+
"""
25+
MAXBIT: normalize sensitivity metrics of layer candidates by max-bitwidth candidate (of that layer).
26+
MINBIT: normalize sensitivity metrics of layer candidates by min-bitwidth candidate (of that layer).
27+
NONE: no normalization.
28+
"""
29+
MAXBIT = 'MAXBIT'
30+
MINBIT = 'MINBIT'
31+
NONE = 'NONE'
32+
33+
2234
@dataclass
2335
class MixedPrecisionQuantizationConfig:
2436
"""
@@ -27,7 +39,6 @@ class MixedPrecisionQuantizationConfig:
2739
Args:
2840
compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
2941
distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
30-
custom_metric_fn (Callable): Function to compute a custom metric. As input gets the model_mp and returns a float value for metric. If None, uses interest point metric.
3142
num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
3243
configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
3344
num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
@@ -36,11 +47,16 @@ class MixedPrecisionQuantizationConfig:
3647
refine_mp_solution (bool): Whether to try to improve the final mixed-precision configuration using a greedy algorithm that searches layers to increase their bit-width, or not.
3748
metric_normalization_threshold (float): A threshold for checking the mixed precision distance metric values, In case of values larger than this threshold, the metric will be scaled to prevent numerical issues.
3849
hessian_batch_size (int): The Hessian computation batch size. used only if using mixed precision with Hessian-based objective.
39-
"""
50+
metric_normalization (MpMetricNormalization): Metric normalization method.
51+
metric_epsilon (float | None): ensure minimal distance between the metric for any non-max-bitwidth candidate
52+
and a max-bitwidth candidate, i.e. metric(non-max-bitwidth) >= metric(max-bitwidth) + epsilon.
53+
If None, the computed metrics are used as is.
54+
custom_metric_fn (Callable): Function to compute a custom metric. As input gets the model_mp and returns a
55+
float value for metric. If None, uses interest point metric.
4056
57+
"""
4158
compute_distance_fn: Optional[Callable] = None
4259
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
43-
custom_metric_fn: Optional[Callable] = None
4460
num_of_images: int = MP_DEFAULT_NUM_SAMPLES
4561
configuration_overwrite: Optional[List[int]] = None
4662
num_interest_points_factor: float = field(default=1.0, metadata={"description": "Should be between 0.0 and 1.0"})
@@ -49,6 +65,9 @@ class MixedPrecisionQuantizationConfig:
4965
refine_mp_solution: bool = True
5066
metric_normalization_threshold: float = 1e10
5167
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
68+
metric_normalization: MpMetricNormalization = MpMetricNormalization.NONE
69+
metric_epsilon: Optional[float] = 1e-6
70+
custom_metric_fn: Optional[Callable] = None
5271
_is_mixed_precision_enabled: bool = field(init=False, default=False)
5372

5473
def __post_init__(self):

model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MixedPrecisionSearchManager
2626
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
2727
ResourceUtilization
28+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
2829
from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
2930
greedy_solution_refinement_procedure
3031

@@ -78,11 +79,12 @@ def search_bit_width(graph: Graph,
7879

7980
# Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
8081
# even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
81-
se = fw_impl.get_sensitivity_evaluator(
82+
se = SensitivityEvaluation(
8283
graph,
8384
mp_config,
8485
representative_data_gen=representative_data_gen,
8586
fw_info=fw_info,
87+
fw_impl=fw_impl,
8688
disable_activation_for_metric=disable_activation_for_metric,
8789
hessian_info_service=hessian_info_service)
8890

@@ -96,10 +98,11 @@ def search_bit_width(graph: Graph,
9698

9799
# Search manager and LP are highly coupled, so LP search method was moved inside search manager.
98100
search_manager = MixedPrecisionSearchManager(graph,
99-
fw_info,
100-
fw_impl,
101-
se,
102-
target_resource_utilization)
101+
fw_info=fw_info,
102+
fw_impl=fw_impl,
103+
sensitivity_evaluator=se,
104+
target_resource_utilization=target_resource_utilization,
105+
mp_config=mp_config)
103106
nodes_bit_cfg = search_manager.search()
104107

105108
graph.skip_validation_check = False

0 commit comments

Comments
 (0)