
Commit 6d740b2

irenab authored and committed
remove activation_quantization_fn from NodeActivationQuantizationCfg
1 parent fad922f commit 6d740b2

17 files changed

Lines changed: 77 additions & 85 deletions


model_compression_toolkit/core/common/framework_info.py

Lines changed: 3 additions & 2 deletions
@@ -52,14 +52,15 @@ class FrameworkInfo(ABC):
         no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)

     Fields:
-        activation_quantizer_mapping (Dict[QuantizationMethod, Callable]): A dictionary mapping from QuantizationMethod to a quantization function.
+        activation_quantizer_factories: A mapping from QuantizationMethod to a factory function that accepts
+            activation bitwidth and a dict of quantization params, and returns the corresponding quantization function.
         kernel_channels_mapping (Dict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
         kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
         out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
         _layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
     """

-    activation_quantizer_mapping: Dict[QuantizationMethod, Callable]
+    activation_quantizer_factories: Dict[QuantizationMethod, Callable[[int, dict], Callable]]
     kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
     kernel_ops_attribute_mapping: Dict[Any, str]
     out_channel_axis_mapping: Dict[Any, int]
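Note: the renamed field's contract is that each value is called as factory(n_bits, quantization_params) and returns a tensor-to-tensor quantizer. A minimal self-contained sketch of that Callable[[int, dict], Callable] shape (toy names and a toy uniform quantizer, not MCT code):

import numpy as np
from typing import Callable, Dict

def uniform_factory(n_bits: int, params: dict) -> Callable:
    # Toy factory: bind bitwidth and params, return a fake-quantizer.
    lo, hi = params['min'], params['max']
    scale = (hi - lo) / (2 ** n_bits - 1)

    def quantize(x: np.ndarray) -> np.ndarray:
        # Snap onto the uniform grid and clip to the quantization range.
        return np.clip(np.round((x - lo) / scale) * scale + lo, lo, hi)

    return quantize

factories: Dict[str, Callable[[int, dict], Callable]] = {'uniform': uniform_factory}
quantizer = factories['uniform'](8, {'min': -1.0, 'max': 1.0})
print(quantizer(np.array([-2.0, 0.1234, 3.0])))  # values snapped to the 8-bit grid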

model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py

Lines changed: 4 additions & 1 deletion
@@ -16,8 +16,10 @@

 import numpy as np

+from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantizer


 def verify_candidates_descending_order(node_q_cfg: List[CandidateNodeQuantizationConfig],
@@ -105,6 +107,7 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
     activation_quantizers = []
     for index, qc in enumerate(node_q_cfg):
         q_activation = node_q_cfg[index].activation_quantization_cfg
-        activation_quantizers.append(q_activation.quantize_node_output)
+        quantizer = get_activation_quantizer(q_activation)
+        activation_quantizers.append(quantizer)

     return activation_quantizers

model_compression_toolkit/core/common/network_editors/actions.py

Lines changed: 0 additions & 11 deletions
@@ -250,9 +250,6 @@ def apply(self, node: BaseNode, graph):
         node.final_activation_quantization_cfg.set_activation_quantization_params_fn(
             activation_quantization_params_fn)

-        activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(self.activation_quantization_method)
-
-        node.final_activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
         node.final_activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method


@@ -283,15 +280,7 @@ def apply(self, node: BaseNode, graph):
         for qc in node.candidates_quantization_cfg:
             activation_quantization_params_fn = get_activation_quantization_params_fn(
                 self.activation_quantization_method)
-
             qc.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
-            activation_quantization_fn = get_fw_info().activation_quantizer_mapping.get(
-                self.activation_quantization_method)
-
-            if activation_quantization_fn is None:
-                Logger.critical('Unknown activation quantization method specified.')  # pragma: no cover
-
-            qc.activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
             qc.activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method


model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 6 additions & 42 deletions
@@ -17,16 +17,14 @@
 import numpy as np

 from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
-from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
     get_weights_quantization_params_fn

 from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
-    AttributeQuantizationConfig, \
-    OpQuantizationConfig
+    AttributeQuantizationConfig, OpQuantizationConfig

 if TYPE_CHECKING:
     from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
@@ -85,18 +83,14 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
     """
     def __init__(self,
                  op_cfg: OpQuantizationConfig,
-                 activation_quantization_fn: Callable,
-                 activation_quantization_params_fn: Callable
-                 ):
+                 activation_quantization_params_fn: Callable):
         """

         Args:
             op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
-            activation_quantization_fn: Function to use when quantizing the node's activations.
             activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
         """

-        self.activation_quantization_fn = activation_quantization_fn
         self.activation_quantization_params_fn = activation_quantization_params_fn
         self.activation_quantization_method = op_cfg.activation_quantization_method
         self.activation_n_bits = op_cfg.activation_n_bits
@@ -152,36 +146,6 @@ def quantization_preserving(self):
     def fln_quantization(self):
         return self.quant_mode == ActivationQuantizationMode.FLN_QUANT

-    def quantize_node_output(self,
-                             tensors: Any) -> Any:
-        """
-
-        Args:
-            tensors: framework tensor/s
-
-        Returns:
-            Framework tensor/s after applying fake quantization.
-
-        """
-        fake_quant = self.activation_quantization_fn(self.activation_n_bits,
-                                                     self.activation_quantization_params)
-
-        if fake_quant is None:
-            Logger.critical(
-                "Layer is intended to be quantized, but the fake_quant function is None.")  # pragma: no cover
-
-        return fake_quant(tensors)
-
-    def set_activation_quantization_fn(self, activation_quantization_fn: Callable):
-        """
-        Sets activation quantization function for the node.
-
-        Args:
-            activation_quantization_fn: Function for quantazing the activations.
-
-        """
-        self.activation_quantization_fn = activation_quantization_fn
-
     def set_activation_quantization_params_fn(self, activation_quantization_params_fn: Callable):
         """
         Sets activation params function for the node.
@@ -218,8 +182,7 @@ def __eq__(self, other: Any) -> bool:
         if not isinstance(other, NodeActivationQuantizationConfig):
             return False  # pragma: no cover

-        return self.activation_quantization_fn == other.activation_quantization_fn and \
-               self.activation_quantization_params_fn == other.activation_quantization_params_fn and \
+        return self.activation_quantization_params_fn == other.activation_quantization_params_fn and \
                self.activation_error_method == other.activation_error_method and \
                self.activation_quantization_method == other.activation_quantization_method and \
                self.activation_n_bits == other.activation_n_bits and \
@@ -234,8 +197,7 @@ def __eq__(self, other: Any) -> bool:
                self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation

     def __hash__(self):
-        return hash((self.activation_quantization_fn,
-                     self.activation_quantization_params_fn,
+        return hash((self.activation_quantization_params_fn,
                      self.activation_error_method,
                      self.activation_quantization_method,
                      self.activation_n_bits,
@@ -263,6 +225,8 @@ def __init__(self,
             weights_attr_cfg: AttributeQuantizationConfig with parameters to use when creating the node's attribute quantization config.
             weights_channels_axis: Axis to quantize a node's attribute when quantizing per-channel (if not quantizing per-channel than expecting None).
         """
+        # TODO irena remove functions.
+        from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
         self.weights_quantization_fn = get_weights_quantization_fn(weights_attr_cfg.weights_quantization_method)
         self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_attr_cfg.weights_quantization_method)
         self.weights_channels_axis = weights_channels_axis

model_compression_toolkit/core/common/quantization/quantization_fn_selection.py

Lines changed: 32 additions & 1 deletion
@@ -14,15 +14,46 @@
 # ==============================================================================

 from collections.abc import Callable
-from functools import partial

 from mct_quantizers import QuantizationMethod
+
+from model_compression_toolkit.core.common.framework_info import get_fw_info
+from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
 from model_compression_toolkit.core.common.quantization.quantizers.uniform_quantizers import power_of_two_quantizer, \
     symmetric_quantizer, uniform_quantizer


+def get_activation_quantizer_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
+    """
+    Get factory for activation quantizer.
+
+    Args:
+        quantization_method: quantization method for activation.
+
+    Returns:
+        Factory that accepts activation bitwidth and a dict of quantization params, and returns the quantizer.
+    """
+    return get_fw_info().activation_quantizer_factories[quantization_method]
+
+
+def get_activation_quantizer(activation_quantization_cfg: NodeActivationQuantizationConfig) -> Callable:
+    """
+    Get activation quantizer based on activation quantization configuration.
+
+    Args:
+        activation_quantization_cfg: activation quantization configuration.
+
+    Returns:
+        Activation quantizer that accepts a tensor and returns a quantized tensor.
+    """
+    quantizer_factory = get_activation_quantizer_factory(activation_quantization_cfg.activation_quantization_method)
+    quantizer = quantizer_factory(activation_quantization_cfg.activation_n_bits,
+                                  activation_quantization_cfg.activation_quantization_params)
+    return quantizer
+
+
 def get_weights_quantization_fn(weights_quantization_method: QuantizationMethod) -> Callable:
     """
     Generate a function for weight quantization.
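In other words, quantizer construction is now two lookups: method -> factory (via FrameworkInfo), then (n_bits, params) -> quantizer. A self-contained sketch of how get_activation_quantizer composes them, with a stand-in for NodeActivationQuantizationConfig (toy names and an identity quantizer so it runs anywhere):

from dataclasses import dataclass, field
from typing import Callable, Dict

@dataclass
class ToyActivationCfg:
    # Only the three fields the lookup needs.
    activation_quantization_method: str
    activation_n_bits: int
    activation_quantization_params: dict = field(default_factory=dict)

FACTORIES: Dict[str, Callable[[int, dict], Callable]] = {
    'noop': lambda n_bits, params: (lambda x: x),  # identity stand-in quantizer
}

def toy_get_activation_quantizer(cfg: ToyActivationCfg) -> Callable:
    factory = FACTORIES[cfg.activation_quantization_method]  # step 1: method -> factory
    return factory(cfg.activation_n_bits, cfg.activation_quantization_params)  # step 2: bind bitwidth/params

quantizer = toy_get_activation_quantizer(ToyActivationCfg('noop', 8))
assert quantizer([1, 2, 3]) == [1, 2, 3]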

model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py

Lines changed: 3 additions & 1 deletion
@@ -19,6 +19,7 @@
 from model_compression_toolkit.core.common import BaseNode, Graph
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantizer


 def get_previous_node_with_activation_quantization(linear_node: BaseNode,
@@ -105,7 +106,8 @@ def compute_activation_bias_correction(graph: Graph,
     float_centers = calculate_bin_centers(float_bins)

     # Quantize the bin edges and calculate the centers of the quantized bins
-    quant_bins = prev_node_act_quant_cfg.quantize_node_output(fw_impl.to_tensor(float_bins))
+    activation_quantizer = get_activation_quantizer(prev_node_act_quant_cfg)
+    quant_bins = activation_quantizer(fw_impl.to_tensor(float_bins))
     quant_bins = fw_impl.to_numpy(quant_bins)
     quant_centers = calculate_bin_centers(quant_bins)
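calculate_bin_centers is not part of this diff; for readers following the correction math, a plausible implementation is the midpoint of adjacent histogram bin edges (an assumption about a helper not shown here):

import numpy as np

def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
    # Midpoint of each pair of adjacent bin edges.
    return (bin_edges[:-1] + bin_edges[1:]) / 2.0

print(calculate_bin_centers(np.array([0.0, 1.0, 2.0])))  # [0.5 1.5]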

model_compression_toolkit/core/common/substitutions/shift_negative_activation.py

Lines changed: 4 additions & 1 deletion
@@ -18,6 +18,8 @@

 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
     ActivationQuantizationMode
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
+    get_activation_quantizer_factory
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
@@ -325,7 +327,8 @@ def shift_negative_function(graph: Graph,
                 'float32')  # Change to type float32 to support tensorflow dtypes
             for _shift_value in _q_points:
                 _hist_bins = hist_bins.astype(np.float32) + _shift_value
-                fw_quant_fn = non_linear_node_cfg_candidate.activation_quantization_fn(non_linear_node_cfg_candidate.activation_n_bits,qparams)
+                quantizer_factory = get_activation_quantizer_factory(non_linear_node_cfg_candidate.activation_quantization_method)
+                fw_quant_fn = quantizer_factory(non_linear_node_cfg_candidate.activation_n_bits, qparams)
                 """
                 In SNC, when better shifting values are tested for better choice,
                 the histogram (which is a numpy object) is quantized using the non-linear node activation
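The docstring above (truncated by the hunk) describes quantizing the shifted histogram once per candidate shift value. A simplified, self-contained sketch of that search loop; the error metric and weighting are assumptions, not the repository's exact logic:

import numpy as np

def pick_best_shift(hist_bins, counts, shift_candidates, quantizer):
    # Try each candidate shift, quantize the shifted bin edges, and keep the
    # shift with the smallest count-weighted squared error.
    best_shift, best_err = shift_candidates[0], np.inf
    for shift in shift_candidates:
        shifted = hist_bins.astype(np.float32) + shift
        quantized = quantizer(shifted)
        err = float(np.sum(counts * (shifted[:-1] - quantized[:-1]) ** 2))
        if err < best_err:
            best_shift, best_err = shift, err
    return best_shift

bins = np.linspace(-1.0, 1.0, 5)
print(pick_best_shift(bins, np.ones(len(bins) - 1), [0.0, 0.5, 1.0], np.round))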

model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,7 @@
 from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantizer
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
 from tensorflow.python.util.object_identity import Reference as TFReference

@@ -56,4 +57,5 @@ def _quantize_node_activations(self,
             Output of the node.

         """
-        return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
+        activation_quantizer = get_activation_quantizer(node.final_activation_quantization_cfg)
+        return activation_quantizer(input_tensors)

model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py

Lines changed: 3 additions & 2 deletions
@@ -17,9 +17,9 @@

 import torch

-from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantizer
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
     PytorchModel
@@ -60,7 +60,8 @@ def _quantize_node_activations(self,
         if node.is_activation_quantization_enabled():
             if isinstance(input_tensors, list):
                 input_tensors = torch.cat(input_tensors, dim=0)
-            return node.final_activation_quantization_cfg.quantize_node_output(input_tensors)
+            activation_quantizer = get_activation_quantizer(node.final_activation_quantization_cfg)
+            return activation_quantizer(input_tensors)
         return input_tensors


model_compression_toolkit/quantization_preparation/load_fqc.py

Lines changed: 0 additions & 2 deletions
@@ -209,10 +209,8 @@ def _create_candidate(weight_channel_axis: ChannelAxisMapping,
     """

    # TODO irena: i think we shouldn't inject methods here, it's quantization implementation, not configuration
-    activation_quantization_fn = get_fw_info().activation_quantizer_mapping[op_cfg.activation_quantization_method]
    activation_quantization_params_fn = get_activation_quantization_params_fn(op_cfg.activation_quantization_method)
    aqc = NodeActivationQuantizationConfig(op_cfg=op_cfg,
-                                           activation_quantization_fn=activation_quantization_fn,
                                           activation_quantization_params_fn=activation_quantization_params_fn)

    # TODO: remove this validation and warning once enabling all attributes quantization by default
