Commit f90d976

Merge branch 'main' into nightly-workflow-reponame
2 parents: 7b5a82e + 3644002

30 files changed

Lines changed: 146 additions & 76 deletions

model_compression_toolkit/core/common/framework_info.py

Lines changed: 0 additions & 6 deletions
@@ -13,14 +13,10 @@
 # limitations under the License.
 # ==============================================================================

-
-from collections.abc import Callable
 from enum import Enum
 from typing import Dict, Any, Tuple, NamedTuple, Optional
 from abc import ABC, abstractmethod

-from mct_quantizers import QuantizationMethod
-

 class ChannelAxis(Enum):
     """
@@ -56,13 +52,11 @@ class FrameworkInfo(ABC):
         kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
         out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
         _layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
-        activation_quantizer_factory_mapping: A mapping from QuantizationMethod to a factory function that accepts activation bitwidth and a dict of quantization params, and returns the corresponding quantization function.
     """

     kernel_ops_attribute_mapping: Dict[Any, str]
     kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
     out_channel_axis_mapping: Dict[Any, int]
-    activation_quantizer_factory_mapping: Dict[QuantizationMethod, Callable[[int, dict], Callable]]

     _layer_min_max_mapping: Dict[Any, tuple]
     _default_channel_mapping = ChannelAxisMapping(None, None)
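
Taken together, these two hunks drop the activation quantizer factory mapping from the abstract FrameworkInfo, so the lookup no longer goes through the framework-info singleton. A hedged before/after sketch of the lookup path (the "before" line is the body removed from quantization_fn_selection.py below; the "after" assumes the caller is handed the framework's own factory function):

    # Before this commit: global lookup through the framework-info singleton.
    factory = get_fw_info().activation_quantizer_factory_mapping[quantization_method]

    # After this commit: the framework-specific factory function is injected by the caller.
    factory = get_activation_quantization_fn_factory(quantization_method)
    quantizer = factory(n_bits, quantization_params)  # bitwidth + params dict -> quantization fn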

model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py

Lines changed: 4 additions & 2 deletions
@@ -92,22 +92,24 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig],
     return quantized_weights


-def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]) -> List:
+def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig],
+                               get_activation_quantization_fn_factory: Callable) -> List:
     """
     Builds a list of quantizers for each of the bitwidth candidates for activation quantization,
     to be stored and used during MP search.

     Args:
         node_q_cfg: Quantization configuration candidates of the node that generated the layer that will
             use this quantizer.
+        get_activation_quantization_fn_factory: activation quantization functions factory.

     Returns: a list of activation quantizers - for each bitwidth and layer's attribute to be quantized.
     """

     activation_quantizers = []
     for index, qc in enumerate(node_q_cfg):
         q_activation = node_q_cfg[index].activation_quantization_cfg
-        quantizer = get_activation_quantization_fn(q_activation)
+        quantizer = get_activation_quantization_fn(q_activation, get_activation_quantization_fn_factory)
         activation_quantizers.append(quantizer)

     return activation_quantizers
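
For illustration, this is how the new signature is exercised on the Keras side (it mirrors the call added in configurable_activation_quantizer.py further down in this commit):

    from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import \
        get_activation_quantization_fn_factory

    activation_quantizers = init_activation_quantizers(node_q_cfg, get_activation_quantization_fn_factory)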

model_compression_toolkit/core/common/quantization/quantization_fn_selection.py

Lines changed: 3 additions & 15 deletions
@@ -17,33 +17,21 @@

 from mct_quantizers import QuantizationMethod

-from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
 from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
     symmetric_quantizer, uniform_quantizer


-def get_activation_quantization_fn_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
-    """
-    Get factory for activation quantizer.
-
-    Args:
-        quantization_method: quantization method for activation.
-
-    Returns:
-        Factory that accepts activation bitwidth and a dict of quantization params, and returns the quantizer.
-    """
-    return get_fw_info().activation_quantizer_factory_mapping[quantization_method]
-
-
-def get_activation_quantization_fn(activation_quantization_cfg: NodeActivationQuantizationConfig) -> Callable:
+def get_activation_quantization_fn(activation_quantization_cfg: NodeActivationQuantizationConfig,
+                                   get_activation_quantization_fn_factory: Callable) -> Callable:
     """
     Get activation quantizer based on activation quantization configuration.

     Args:
         activation_quantization_cfg: activation quantization configuration.
+        get_activation_quantization_fn_factory: activation quantization functions factory.

     Returns:
         Activation quantizer that accepts a tensor and returns a quantized tensor.
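
The updated body of get_activation_quantization_fn is not included in this excerpt. A hedged sketch of the expected call path, assuming the usual NodeActivationQuantizationConfig attribute names (activation_quantization_method, activation_n_bits, and activation_quantization_params are assumptions, not shown in the diff):

    # Hypothetical body sketch of get_activation_quantization_fn:
    quantizer_factory = get_activation_quantization_fn_factory(
        activation_quantization_cfg.activation_quantization_method)
    return quantizer_factory(activation_quantization_cfg.activation_n_bits,
                             activation_quantization_cfg.activation_quantization_params)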

model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py

Lines changed: 10 additions & 5 deletions
@@ -67,7 +67,8 @@ def compute_activation_bias_correction(graph: Graph,
                                        fw_impl: FrameworkImplementation,
                                        linear_node: BaseNode,
                                        prev_node: BaseNode,
-                                       kernel_size: str) -> Graph:
+                                       kernel_size: str,
+                                       get_activation_quantization_fn_factory: Callable) -> Graph:
     """
     Compute the activation bias correction term, and store it in the final activation
     quantization configuration.
@@ -79,6 +80,7 @@ def compute_activation_bias_correction(graph: Graph,
         linear_node: Node to compute the activation bias correction for.
         prev_node: Node to compute the activation error caused by his activation quantization.
         kernel_size: The framework specific attribute name of the convolution layer's kernel size.
+        get_activation_quantization_fn_factory: activation quantization functions factory.

     Returns:
         Graph with activation bias correction term for each node.
@@ -105,7 +107,8 @@ def compute_activation_bias_correction(graph: Graph,
     float_centers = calculate_bin_centers(float_bins)

     # Quantize the bin edges and calculate the centers of the quantized bins
-    activation_quantizer = get_activation_quantization_fn(prev_node_act_quant_cfg)
+    activation_quantizer = get_activation_quantization_fn(prev_node_act_quant_cfg,
+                                                          get_activation_quantization_fn_factory)
     quant_bins = activation_quantizer(fw_impl.to_tensor(float_bins))
     quant_bins = fw_impl.to_numpy(quant_bins)
     quant_centers = calculate_bin_centers(quant_bins)
@@ -150,7 +153,8 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
                                                 quant_config: QuantizationConfig,
                                                 fw_impl: FrameworkImplementation,
                                                 activation_bias_correction_node_matchers: Callable,
-                                                kernel_size: str) -> Graph:
+                                                kernel_size: str,
+                                                get_activation_quantization_fn_factory: Callable) -> Graph:
     """
     Compute the activation bias correction term for the graph.

@@ -160,7 +164,7 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
         kernel_size: The framework specific attribute name of the convolution layer's kernel size.
-
+        get_activation_quantization_fn_factory: activation quantization functions factory.

     Returns:
         Graph with activation bias correction term for each relevant node.
@@ -176,5 +180,6 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
                                                        fw_impl=fw_impl,
                                                        linear_node=n,
                                                        prev_node=prev_node,
-                                                       kernel_size=kernel_size)
+                                                       kernel_size=kernel_size,
+                                                       get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)
     return graph
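
The framework-side call sites are not part of this excerpt; a hedged sketch of the wiring this change implies for a framework (only the factory-forwarding pattern is confirmed by the diff; the matcher and constant names below are placeholders):

    graph = compute_activation_bias_correction_of_graph(
        graph=graph,
        quant_config=quant_config,
        fw_impl=fw_impl,
        activation_bias_correction_node_matchers=node_matchers,  # framework-specific matcher fn
        kernel_size=KERNEL_SIZE,  # framework-specific kernel-size attribute name
        get_activation_quantization_fn_factory=get_activation_quantization_fn_factory)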

model_compression_toolkit/core/common/substitutions/shift_negative_activation.py

Lines changed: 8 additions & 3 deletions
@@ -18,8 +18,6 @@

 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
-from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
-    get_activation_quantization_fn_factory
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
@@ -253,6 +251,7 @@ def shift_negative_function(graph: Graph,
                             padding_str: str,
                             bias_str: str,
                             bias_flag_str: str,
+                            get_activation_quantization_fn_factory: Callable,
                             zero_padding_node: BaseNode = None,
                             bypass_nodes: List = None,
                             params_search_quantization_fn: Callable = None
@@ -278,6 +277,7 @@ def shift_negative_function(graph: Graph,
         padding_str: The framework specific attribute name of the padding.
         bias_str: The framework specific attribute name of the bias.
         bias_flag_str: The framework specific attribute name of the bias flag.
+        get_activation_quantization_fn_factory: activation quantization functions factory.
         zero_padding_node: ZeroPadding2D node that may be in the graph before the linear layer.
         params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.

@@ -335,7 +335,7 @@ def shift_negative_function(graph: Graph,
         the histogram (which is a numpy object) is quantized using the non-linear node activation
         quantization function (to estimate the expected mse comparing to the original histogram).
         The quantization function is a framework function, which makes it fail since it
-        expects a fw tensor. The commmon part of SNC receives an argument which is a callable
+        expects a fw tensor. The common part of SNC receives an argument which is a callable
         that receives two argument and returns one: it gets the fw activation quantization function
         and the bins to quantize. The function (of each fw) responsible for doing (if needed) a preprocessing and postprocessing
         to the bins which is a numpy object.
@@ -569,6 +569,7 @@ def apply_shift_negative_correction(graph: Graph,
                                     padding_str: str,
                                     bias_str: str,
                                     bias_flag_str: str,
+                                    get_activation_quantization_fn_factory: Callable,
                                     params_search_quantization_fn: Callable=None) -> Graph:
     """
     Apply the substitution even if the linear node is not immediately after
@@ -590,6 +591,9 @@ def apply_shift_negative_correction(graph: Graph,
         padding_str: The framework specific attribute name of the padding.
         bias_str: The framework specific attribute name of the bias.
         bias_flag_str: The framework specific attribute name of the bias flag.
+        get_activation_quantization_fn_factory: activation quantization functions factory.
+        params_search_quantization_fn: Function to quantize np tensor using a framework (tf/torch) quantization method. Needed for better param_search estimating the expected loss.
+
     Returns:
         Graph after applying shift negative on selected activations.
     """
@@ -620,6 +624,7 @@ def apply_shift_negative_correction(graph: Graph,
                                            padding_str,
                                            bias_str,
                                            bias_flag_str,
+                                           get_activation_quantization_fn_factory,
                                            zero_padding_node=pad_node,
                                            bypass_nodes=bypass_nodes,
                                            params_search_quantization_fn=params_search_quantization_fn)
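
The docstring above describes params_search_quantization_fn only in prose: a callable that takes the fw activation quantization function plus the numpy histogram bins, and returns the quantized bins. A hypothetical Keras-flavored example of such a callable (the name and tensor-conversion details are assumptions, not taken from this commit):

    import numpy as np
    import tensorflow as tf

    def keras_params_search_quantization_fn(activation_quantizer, bins: np.ndarray) -> np.ndarray:
        # Preprocess: numpy bins -> fw tensor, since the quantizer expects a fw tensor.
        quantized = activation_quantizer(tf.convert_to_tensor(bins, dtype=tf.float32))
        # Postprocess: fw tensor -> numpy, so the common SNC code can compare histograms.
        return quantized.numpy()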

model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py

Lines changed: 3 additions & 2 deletions
@@ -14,10 +14,10 @@
 # ==============================================================================
 from typing import List

-from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
+from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
 from tensorflow.python.util.object_identity import Reference as TFReference

@@ -57,5 +57,6 @@ def _quantize_node_activations(self,
             Output of the node.

         """
-        activation_quantizer = get_activation_quantization_fn(node.final_activation_quantization_cfg)
+        activation_quantizer = get_activation_quantization_fn(node.final_activation_quantization_cfg,
+                                                              get_activation_quantization_fn_factory)
         return activation_quantizer(input_tensors)

model_compression_toolkit/core/keras/default_framework_info.py

Lines changed: 0 additions & 13 deletions
@@ -18,19 +18,16 @@
 from typing import Tuple, Any, Dict
 from functools import wraps

-from model_compression_toolkit.core.keras.quantizer.lut_fake_quant import activation_lut_kmean_quantizer
 from packaging import version

 if version.parse(tf.__version__) >= version.parse("2.13"):
     from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU, Activation
 else:
     from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Softmax, ELU, Activation  # pragma: no cover
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
-from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.constants import SOFTMAX_THRESHOLD, ACTIVATION
 from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
     KERNEL, DEPTHWISE_KERNEL, GELU
-from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization


 class KerasInfo(FrameworkInfo):
@@ -69,16 +66,6 @@ class KerasInfo(FrameworkInfo):
                                   Dense: -1,
                                   Conv2DTranspose: -1}

-    """
-    Mapping from a QuantizationMethod to an activation quantizer function.
-    """
-    activation_quantizer_factory_mapping = {
-        QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
-        QuantizationMethod.SYMMETRIC: symmetric_quantization,
-        QuantizationMethod.UNIFORM: uniform_quantization,
-        QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer
-    }
-
     """
     Map from an activation function name to its min/max output values (if known).
     The values are used for tensor min/max values initialization.
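
The mapping removed here presumably lands in the new module model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py that other hunks in this commit import from. That file is not shown in this excerpt; below is a sketch reconstructed from the removed code (the import paths assume the quantizer → quantization package rename recorded at the end of this diff):

    from typing import Callable

    from mct_quantizers import QuantizationMethod
    from model_compression_toolkit.core.keras.quantization.fake_quant_builder import \
        power_of_two_quantization, symmetric_quantization, uniform_quantization
    from model_compression_toolkit.core.keras.quantization.lut_fake_quant import activation_lut_kmean_quantizer

    # Same mapping as the dict removed from KerasInfo above.
    _activation_quantizer_factory_mapping = {
        QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
        QuantizationMethod.SYMMETRIC: symmetric_quantization,
        QuantizationMethod.UNIFORM: uniform_quantization,
        QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer,
    }


    def get_activation_quantization_fn_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
        # The factory accepts activation bitwidth and a dict of quantization
        # params, and returns the corresponding quantization function.
        return _activation_quantizer_factory_mapping[quantization_method]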

model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py

Lines changed: 3 additions & 1 deletion
@@ -34,6 +34,7 @@
     NodeFrameworkAttrMatcher
 from model_compression_toolkit.core.common.substitutions.shift_negative_activation import \
     apply_shift_negative_correction
+from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.core.keras.constants import KERNEL_SIZE, STRIDES, ACTIVATION, SWISH, \
     SELU, GELU, FUNCTION, ADD, PAD
 from model_compression_toolkit.core.keras.constants import NEGATIVE_SLOPE, PADDING, PAD_SAME, PAD_VALID, BIAS, USE_BIAS
@@ -252,5 +253,6 @@ def keras_apply_shift_negative_correction(graph: Graph,
                                           is_padding_node_and_node_has_padding,
                                           PADDING,
                                           BIAS,
-                                          USE_BIAS
+                                          USE_BIAS,
+                                          get_activation_quantization_fn_factory
                                           )

model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@
     verify_candidates_descending_order, init_activation_quantizers
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
+from model_compression_toolkit.core.keras.quantization.activation_quantization_fn_factory import get_activation_quantization_fn_factory
 from model_compression_toolkit.logger import Logger

 import tensorflow as tf
@@ -67,7 +68,7 @@ def __init__(self,
         if qc.activation_quantization_cfg.quant_mode != node_q_cfg[0].activation_quantization_cfg.quant_mode:
             Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).")  # pragma: no cover

-        self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
+        self.activation_quantizers = init_activation_quantizers(self.node_q_cfg, get_activation_quantization_fn_factory)
         self.active_quantization_config_index = max_candidate_idx  # initialize with first config as default

     def set_active_activation_quantizer(self, index: Optional[int]):

model_compression_toolkit/core/keras/quantizer/__init__.py renamed to model_compression_toolkit/core/keras/quantization/__init__.py

File renamed without changes.
