Skip to content

Commit 98c3f2d

Browse files
irenab
authored and committed
fixes
1 parent 62d9eef commit 98c3f2d

18 files changed

Lines changed: 102 additions & 117 deletions

File tree

model_compression_toolkit/core/common/framework_info.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,19 @@ class FrameworkInfo(ABC):
5252
no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
5353
5454
Fields:
55-
activation_quantizer_factories: A mapping from QuantizationMethod to a factory function that accepts
56-
activation bitwidth and a dict of quantization params, and returns the corresponding quantization function.
5755
kernel_channels_mapping (Dict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
5856
kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
5957
out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
6058
_layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
59+
activation_quantizer_factory_mapping: A mapping from QuantizationMethod to a factory function that accepts activation bitwidth and a dict of quantization params, and returns the corresponding quantization function.
6160
"""
6261

63-
activation_quantizer_factories: Dict[QuantizationMethod, Callable[[int, dict], Callable]]
64-
kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
6562
kernel_ops_attribute_mapping: Dict[Any, str]
63+
kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
6664
out_channel_axis_mapping: Dict[Any, int]
67-
_layer_min_max_mapping: Dict[Any, tuple]
65+
activation_quantizer_factory_mapping: Dict[QuantizationMethod, Callable[[int, dict], Callable]]
6866

67+
_layer_min_max_mapping: Dict[Any, tuple]
6968
_default_channel_mapping = ChannelAxisMapping(None, None)
7069

7170
@classmethod

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -873,11 +873,7 @@ def override_fused_node_activation_quantization_candidates(self):
873873
fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
874874
if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
875875
def update(qc):
876-
qc.activation_quantization_cfg = NodeActivationQuantizationConfig(
877-
fusing_op_quantization_cfg,
878-
qc.activation_quantization_cfg.activation_quantization_fn,
879-
qc.activation_quantization_cfg.activation_quantization_params_fn
880-
)
876+
qc.activation_quantization_cfg = NodeActivationQuantizationConfig(fusing_op_quantization_cfg)
881877
qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
882878
node.quantization_cfg.update_all(update)
883879
node.quantization_cfg.remove_duplicates()

model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
2020
CandidateNodeQuantizationConfig
21-
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import (get_activation_quantizer,
21+
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import (get_activation_quantization_fn,
2222
get_weights_quantization_fn)
2323

2424

@@ -107,7 +107,7 @@ def init_activation_quantizers(node_q_cfg: List[CandidateNodeQuantizationConfig]
107107
activation_quantizers = []
108108
for index, qc in enumerate(node_q_cfg):
109109
q_activation = node_q_cfg[index].activation_quantization_cfg
110-
quantizer = get_activation_quantizer(q_activation)
110+
quantizer = get_activation_quantization_fn(q_activation)
111111
activation_quantizers.append(quantizer)
112112

113113
return activation_quantizers

model_compression_toolkit/core/common/network_editors/actions.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@
2020
from mct_quantizers import QuantizationMethod
2121
from model_compression_toolkit.core.common import Graph
2222
from model_compression_toolkit.logger import Logger
23-
24-
2523
from model_compression_toolkit.core.common.graph.base_node import BaseNode
26-
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
27-
get_weights_quantization_fn
24+
2825

2926
_EditRule = namedtuple('EditRule', 'filter action')
3027

model_compression_toolkit/core/common/quantization/quantization_fn_selection.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
symmetric_quantizer, uniform_quantizer
2626

2727

28-
def get_activation_quantizer_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
28+
def get_activation_quantization_fn_factory(quantization_method: QuantizationMethod) -> Callable[[int, dict], Callable]:
2929
"""
3030
Get factory for activation quantizer.
3131
@@ -35,10 +35,10 @@ def get_activation_quantizer_factory(quantization_method: QuantizationMethod) ->
3535
Returns:
3636
Factory that accepts activation bitwidth and a dict of quantization params, and returns the quantizer.
3737
"""
38-
return get_fw_info().activation_quantizer_mapping[quantization_method]
38+
return get_fw_info().activation_quantizer_factory_mapping[quantization_method]
3939

4040

41-
def get_activation_quantizer(activation_quantization_cfg: NodeActivationQuantizationConfig) -> Callable:
41+
def get_activation_quantization_fn(activation_quantization_cfg: NodeActivationQuantizationConfig) -> Callable:
4242
"""
4343
Get activation quantizer based on activation quantization configuration.
4444
@@ -48,7 +48,8 @@ def get_activation_quantizer(activation_quantization_cfg: NodeActivationQuantiza
4848
Returns:
4949
Activation quantizer that accepts a tensor and returns a quantized tensor.
5050
"""
51-
quantizer_factory = get_activation_quantizer_factory(activation_quantization_cfg.activation_quantization_method)
51+
quantizer_factory = get_activation_quantization_fn_factory(
52+
activation_quantization_cfg.activation_quantization_method)
5253
quantizer = quantizer_factory(activation_quantization_cfg.activation_n_bits,
5354
activation_quantization_cfg.activation_quantization_params)
5455
return quantizer

model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,46 @@
2525
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod
2626

2727

28+
def compute_activation_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
29+
node_prior_info: NodePriorInfo,
30+
out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
31+
"""
32+
Compute the activations params for a given node in a graph according to a params function.
33+
34+
Args:
35+
activation_quant_cfg: node's activation quantization configuration.
36+
node_prior_info: Prior info collected for the node that is being quantized.
37+
out_stats_container: Tensor containing output statistics of the node.
38+
39+
Returns:
40+
The computed activation quantization params.
41+
"""
42+
activation_quantization_params_fn = _get_activation_quantization_params_fn(
43+
activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
44+
45+
# Extract and filter histogram data from the statistics container.
46+
bins_values, bins_counts = _get_histogram_data(activation_quant_cfg, out_stats_container)
47+
48+
# Retrieve the minimum and maximum values from the statistics container.
49+
min_value, max_value = out_stats_container.get_min_max_values()
50+
51+
# Determine if the activations should be considered signed.
52+
signed = _determine_signedness(activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)
53+
54+
# Compute and return the activation quantization parameters.
55+
return activation_quantization_params_fn(
56+
bins_values,
57+
bins_counts,
58+
activation_quant_cfg.l_p_value,
59+
activation_quant_cfg.activation_n_bits,
60+
min_value,
61+
max_value,
62+
min_threshold=activation_quant_cfg.min_threshold,
63+
quant_error_method=activation_quant_cfg.activation_error_method,
64+
is_signed=signed
65+
)
66+
67+
2868
def _get_histogram_data(
2969
activation_quant_cfg: NodeActivationQuantizationConfig,
3070
out_stats_container: BaseStatsCollector
@@ -85,46 +125,6 @@ def _determine_signedness(
85125
return np.any(bins_values[:-1][bins_counts > 0] < 0)
86126

87127

88-
def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
89-
node_prior_info: NodePriorInfo,
90-
out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
91-
"""
92-
Compute the activations params for a given node in a graph according to a params function.
93-
94-
Args:
95-
activation_quant_cfg: node's activation quantization configuration.
96-
node_prior_info: Prior info collected for the node that is being quantized.
97-
out_stats_container: Tensor containing output statistics of the node.
98-
99-
Returns:
100-
The computed activation quantization params.
101-
"""
102-
activation_quantization_params_fn = _get_activation_quantization_params_fn(
103-
activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
104-
105-
# Extract and filter histogram data from the statistics container.
106-
bins_values, bins_counts = _get_histogram_data(activation_quant_cfg, out_stats_container)
107-
108-
# Retrieve the minimum and maximum values from the statistics container.
109-
min_value, max_value = out_stats_container.get_min_max_values()
110-
111-
# Determine if the activations should be considered signed.
112-
signed = _determine_signedness(activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)
113-
114-
# Compute and return the activation quantization parameters.
115-
return activation_quantization_params_fn(
116-
bins_values,
117-
bins_counts,
118-
activation_quant_cfg.l_p_value,
119-
activation_quant_cfg.activation_n_bits,
120-
min_value,
121-
max_value,
122-
min_threshold=activation_quant_cfg.min_threshold,
123-
quant_error_method=activation_quant_cfg.activation_error_method,
124-
is_signed=signed
125-
)
126-
127-
128128
_activation_quant_params_fns = {
129129
QuantizationMethod.POWER_OF_TWO: qpg.power_of_two_selection_histogram,
130130
QuantizationMethod.SYMMETRIC: qpg.symmetric_selection_histogram,

model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
2626
HessianScoresGranularity
2727
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
28-
import get_activations_qparams
28+
import compute_activation_qparams
2929
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
3030
compute_weights_qparams
3131
from model_compression_toolkit.logger import Logger
@@ -130,9 +130,8 @@ def calculate_quantization_params(graph: Graph,
130130

131131
if n.is_activation_quantization_enabled():
132132
# If node's activations should be quantized as well, we compute its activation quantization parameters
133-
activation_params = get_activations_qparams(
134-
activation_quant_cfg=candidate_qc.activation_quantization_cfg,
135-
node_prior_info=n.prior_info,
133+
activation_params = compute_activation_qparams(
134+
activation_quant_cfg=candidate_qc.activation_quantization_cfg, node_prior_info=n.prior_info,
136135
out_stats_container=graph.get_out_stats_collector(n))
137136
# Create a NodeQuantizationConfig containing all quantization params and attach it to the node
138137
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)

model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from model_compression_toolkit.core.common.model_collector import ModelCollector
2525
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
2626
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
27-
import get_activations_qparams
27+
import compute_activation_qparams
2828
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
2929
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
3030

@@ -50,12 +50,11 @@ def _collect_and_assign_act_threshold(graph: Graph,
5050
for _data in tqdm(representative_data_gen()):
5151
mi.infer(_data)
5252

53-
for n in list(graph.nodes):
53+
for n in graph.nodes:
5454
if n.is_activation_quantization_enabled():
55-
activation_params = get_activations_qparams(
56-
activation_quant_cfg=n.final_activation_quantization_cfg,
57-
node_prior_info=n.prior_info,
58-
out_stats_container=graph.get_out_stats_collector(n))
55+
activation_params = compute_activation_qparams(activation_quant_cfg=n.final_activation_quantization_cfg,
56+
node_prior_info=n.prior_info,
57+
out_stats_container=graph.get_out_stats_collector(n))
5958
n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params)
6059

6160

model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
from model_compression_toolkit.core import QuantizationConfig
1919
from model_compression_toolkit.core.common import BaseNode, Graph
2020
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
21-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22-
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantizer
21+
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
2322

2423

2524
def get_previous_node_with_activation_quantization(linear_node: BaseNode,
@@ -106,7 +105,7 @@ def compute_activation_bias_correction(graph: Graph,
106105
float_centers = calculate_bin_centers(float_bins)
107106

108107
# Quantize the bin edges and calculate the centers of the quantized bins
109-
activation_quantizer = get_activation_quantizer(prev_node_act_quant_cfg)
108+
activation_quantizer = get_activation_quantization_fn(prev_node_act_quant_cfg)
110109
quant_bins = activation_quantizer(fw_impl.to_tensor(float_bins))
111110
quant_bins = fw_impl.to_numpy(quant_bins)
112111
quant_centers = calculate_bin_centers(quant_bins)

model_compression_toolkit/core/common/substitutions/shift_negative_activation.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
2020
ActivationQuantizationMode
2121
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
22-
get_activation_quantizer_factory
22+
get_activation_quantization_fn_factory
2323
from model_compression_toolkit.logger import Logger
2424
from model_compression_toolkit.core.common import Graph, BaseNode
2525
from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
2626
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
2727
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
2828
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
29-
import get_activations_qparams
29+
import compute_activation_qparams
3030
from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
3131
_mse_error_histogram
3232
from model_compression_toolkit.core.common.quantization.quantization_params_generation import z_score_filter
@@ -327,7 +327,8 @@ def shift_negative_function(graph: Graph,
327327
'float32') # Change to type float32 to support tensorflow dtypes
328328
for _shift_value in _q_points:
329329
_hist_bins = hist_bins.astype(np.float32) + _shift_value
330-
quantizer_factory = get_activation_quantizer_factory(non_linear_node_cfg_candidate.activation_quantization_method)
330+
quantizer_factory = get_activation_quantization_fn_factory(
331+
non_linear_node_cfg_candidate.activation_quantization_method)
331332
fw_quant_fn = quantizer_factory(non_linear_node_cfg_candidate.activation_n_bits, qparams)
332333
"""
333334
In SNC, when better shifting values are tested for better choice,
@@ -471,11 +472,11 @@ def update(c):
471472
op2d_node=op2d_node)
472473

473474
if non_linear_node_cfg_candidate.shift_negative_threshold_recalculation:
474-
activation_param = get_activations_qparams(activation_quant_cfg=non_linear_node_cfg_candidate,
475-
nodes_prior_info=non_linear_node.prior_info,
476-
out_stats_container=graph.get_out_stats_collector(non_linear_node))
475+
activation_param = compute_activation_qparams(activation_quant_cfg=non_linear_node_cfg_candidate,
476+
node_prior_info=non_linear_node.prior_info,
477+
out_stats_container=graph.get_out_stats_collector(non_linear_node))
477478

478-
assert activation_param.get(SIGNED) == False
479+
assert activation_param.get(SIGNED) is False
479480
for candidate_qc in non_linear_node.candidates_quantization_cfg:
480481
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_param)
481482

0 commit comments

Comments (0)