Skip to content

Commit 62d9eef

Browse files
irenab authored and committed
remove weights_quantization_params_fn from WeightsAttrQuantizationConfig
1 parent bf74a58 commit 62d9eef

8 files changed

Lines changed: 77 additions & 188 deletions

File tree

model_compression_toolkit/core/common/network_editors/actions.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323

2424

2525
from model_compression_toolkit.core.common.graph.base_node import BaseNode
26-
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
27-
get_weights_quantization_params_fn
2826
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
2927
get_weights_quantization_fn
3028

@@ -234,7 +232,7 @@ def apply(self, node: BaseNode, graph):
234232

235233
class ChangeFinalWeightsQuantizationMethod(BaseAction):
236234
"""
237-
Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer function.
235+
Class ChangeFinalWeightsQuantizationMethod to change a node's weights/activations quantizer method.
238236
"""
239237

240238
def __init__(self, attr_name: str, weights_quantization_method=None):
@@ -260,21 +258,8 @@ def apply(self, node: BaseNode, graph):
260258
"""
261259

262260
if self.weights_quantization_method is not None and node.final_weights_quantization_cfg is not None:
263-
264-
weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
265-
266-
attr_config = node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
267-
attr_config.override_weights_quantization_params_fn(weights_quantization_params_fn)
268-
269-
weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
270-
271-
if weights_quantization_fn is None:
272-
Logger.critical('Unknown weights quantization method specified.') # pragma: no cover
273-
274261
attr_config = node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
275-
attr_config.override_weights_quantization_fn(weights_quantization_fn)
276-
node.final_weights_quantization_cfg.get_attr_config(self.attr_name).weights_quantization_method = \
277-
self.weights_quantization_method
262+
attr_config.weights_quantization_method = self.weights_quantization_method
278263

279264

280265
class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
@@ -307,18 +292,7 @@ def apply(self, node: BaseNode, graph: Graph):
307292

308293
if self.weights_quantization_method is not None:
309294
for qc in node.candidates_quantization_cfg:
310-
311-
weights_quantization_params_fn = get_weights_quantization_params_fn(self.weights_quantization_method)
312-
313295
attr_qc = qc.weights_quantization_cfg.get_attr_config(self.attr_name)
314-
attr_qc.override_weights_quantization_params_fn(weights_quantization_params_fn)
315-
316-
weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
317-
318-
if weights_quantization_fn is None:
319-
Logger.critical('Unknown weights quantization method specified.') # pragma: no cover
320-
321-
attr_qc.override_weights_quantization_fn(weights_quantization_fn)
322296
attr_qc.weights_quantization_method = self.weights_quantization_method
323297

324298

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from typing import Callable, Any, List, Dict, TYPE_CHECKING
15+
from typing import Any, List, Dict, TYPE_CHECKING
1616
from enum import Enum, auto
17-
import numpy as np
1817

1918
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
2019
from model_compression_toolkit.logger import Logger
21-
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
22-
get_weights_quantization_params_fn
2320

2421
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
2522
from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
@@ -209,8 +206,6 @@ def __init__(self,
209206
weights_attr_cfg: AttributeQuantizationConfig with parameters to use when creating the node's attribute quantization config.
210207
weights_channels_axis: Axis to quantize a node's attribute when quantizing per-channel (if not quantizing per-channel than expecting None).
211208
"""
212-
# TODO irena remove functions.
213-
self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_attr_cfg.weights_quantization_method)
214209
self.weights_channels_axis = weights_channels_axis
215210
self.weights_quantization_method = weights_attr_cfg.weights_quantization_method
216211
self.weights_n_bits = weights_attr_cfg.weights_n_bits
@@ -227,26 +222,6 @@ def set_qc(self, qc: QuantizationConfig):
227222
self.weights_error_method = qc.weights_error_method
228223
self.l_p_value = qc.l_p_value
229224

230-
def override_weights_quantization_fn(self, weights_quantization_fn: Callable):
231-
"""
232-
Override weights quantization function for the node.
233-
234-
Args:
235-
weights_quantization_fn: Function for quantazing the weights.
236-
237-
"""
238-
self.weights_quantization_fn = weights_quantization_fn
239-
240-
def override_weights_quantization_params_fn(self, weights_quantization_params_fn: Callable):
241-
"""
242-
Override weights params function for the node.
243-
244-
Args:
245-
weights_quantization_params_fn: Function for calculating the weights params.
246-
247-
"""
248-
self.weights_quantization_params_fn = weights_quantization_params_fn
249-
250225
def set_weights_quantization_param(self,
251226
weights_params: dict):
252227
"""
@@ -260,31 +235,6 @@ def set_weights_quantization_param(self,
260235
for param_name, param_value in weights_params.items():
261236
self.weights_quantization_params[param_name] = param_value
262237

263-
def calculate_and_set_weights_params(self, tensor_data: np.ndarray, min_threshold: float):
264-
"""
265-
Args:
266-
tensor_data: Tensor content as Numpy array.
267-
min_threshold: A minimal threshold to set as quantization parameter.
268-
269-
Returns:
270-
Recalculated weights quantization params from the kernel and channel axis.
271-
272-
"""
273-
assert self.enable_weights_quantization
274-
assert not (self.weights_per_channel_threshold and self.weights_channels_axis is None), \
275-
"Trying to calculate threshold per channel, channel axis in None."
276-
if self.weights_quantization_params_fn is not None:
277-
self.set_weights_quantization_param(
278-
self.weights_quantization_params_fn(tensor_data,
279-
p=self.l_p_value,
280-
n_bits=self.weights_n_bits,
281-
per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
282-
channel_axis=self.weights_channels_axis.output, # output channel axis
283-
min_threshold=min_threshold)[0] # Take only first output, the q-params, as axis is already chosen.
284-
)
285-
else:
286-
self.set_weights_quantization_param({})
287-
288238
def __eq__(self, other: Any) -> bool:
289239
"""
290240
Compares the object to another object to find if they are equal.
@@ -298,8 +248,7 @@ def __eq__(self, other: Any) -> bool:
298248
if not isinstance(other, WeightsAttrQuantizationConfig):
299249
return False # pragma: no cover
300250

301-
return self.weights_quantization_params_fn == other.weights_quantization_params_fn and \
302-
self.weights_channels_axis == other.weights_channels_axis and \
251+
return self.weights_channels_axis == other.weights_channels_axis and \
303252
self.weights_quantization_method == other.weights_quantization_method and \
304253
self.weights_n_bits == other.weights_n_bits and \
305254
self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
@@ -308,8 +257,7 @@ def __eq__(self, other: Any) -> bool:
308257
self.l_p_value == other.l_p_value
309258

310259
def __hash__(self):
311-
return hash((self.weights_quantization_params_fn,
312-
self.weights_channels_axis,
260+
return hash((self.weights_channels_axis,
313261
self.weights_error_method,
314262
self.weights_quantization_method,
315263
self.weights_n_bits,

model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import (
1818
lut_kmeans_tensor, lut_kmeans_histogram)
1919
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import (
20-
symmetric_no_clipping_selection_min_max, symmetric_selection_histogram)
20+
symmetric_no_clipping_selection_min_max, symmetric_selection_histogram, symmetric_selection_tensor)
2121
from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import (
22-
uniform_no_clipping_selection_min_max, uniform_selection_histogram)
22+
uniform_no_clipping_selection_min_max, uniform_selection_histogram, uniform_selection_tensor)
2323
from model_compression_toolkit.core.common.quantization.quantization_params_generation.outlier_filter import z_score_filter

model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
2828
import get_activations_qparams
2929
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_weights_computation import \
30-
get_weights_qparams
30+
compute_weights_qparams
3131
from model_compression_toolkit.logger import Logger
3232

3333

@@ -119,13 +119,12 @@ def calculate_quantization_params(graph: Graph,
119119
mod_attr_cfg = copy.deepcopy(attr_cfg)
120120
mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE
121121

122-
weights_params, output_channels_axis = get_weights_qparams(n.get_weights_by_keys(attr),
123-
candidate_qc.weights_quantization_cfg,
124-
mod_attr_cfg,
125-
output_channels_axis,
126-
node=n,
127-
hessian_info_service=hessian_info_service,
128-
num_hessian_samples=num_hessian_samples)
122+
min_threshold = candidate_qc.weights_quantization_cfg.min_threshold
123+
weights_params, output_channels_axis = compute_weights_qparams(n.get_weights_by_keys(attr),
124+
mod_attr_cfg, output_channels_axis,
125+
min_threshold=min_threshold, node=n,
126+
hessian_info_service=hessian_info_service,
127+
num_hessian_samples=num_hessian_samples)
129128
attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
130129
attr_cfg.set_weights_quantization_param(weights_params)
131130

model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,38 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from typing import Dict, Any, Tuple
15+
from functools import partial
16+
from typing import Dict, Any, Tuple, Callable, TYPE_CHECKING
1617

1718
import numpy as np
19+
from mct_quantizers import QuantizationMethod
1820

1921
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
2022
from model_compression_toolkit.core.common.hessian import HessianInfoService
21-
from model_compression_toolkit.defaultdict import DefaultDict
22-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23-
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
24-
WeightsAttrQuantizationConfig
23+
from model_compression_toolkit.core.common.quantization.quantization_params_generation import \
24+
power_of_two_selection_tensor, lut_kmeans_tensor, symmetric_selection_tensor, uniform_selection_tensor
2525
from model_compression_toolkit.logger import Logger
2626

27+
if TYPE_CHECKING:
28+
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
2729

28-
def get_weights_qparams(weights_attr_values: np.ndarray,
29-
weights_quant_config: NodeWeightsQuantizationConfig,
30-
attr_quant_config: WeightsAttrQuantizationConfig,
31-
output_channels_axis: int,
32-
node=None,
33-
hessian_info_service: HessianInfoService = None,
34-
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
30+
31+
def compute_weights_qparams(weights_attr_values: np.ndarray,
32+
attr_quant_config: 'WeightsAttrQuantizationConfig',
33+
output_channels_axis: int,
34+
min_threshold: float,
35+
node=None,
36+
hessian_info_service: HessianInfoService = None,
37+
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Tuple[Dict[Any, Any], int]:
3538
"""
3639
Compute thresholds to quantize a kernel according to a NodeWeightsQuantizationConfig
3740
instance.
3841
3942
Args:
4043
weights_attr_values: Weights attribute parameter to compute the quantization thresholds for.
41-
weights_quant_config: Weights quantization configuration to define how the thresholds are computed.
4244
attr_quant_config: A specific weights attribute quantization configuration to get its params.
4345
output_channels_axis: Index of the kernel output channels dimension.
46+
min_threshold: Minimal threshold to use if threshold is too small.
4447
node: The node for which the quantization error is computed (used only with HMSE error method).
4548
hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
4649
num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
@@ -49,22 +52,43 @@ def get_weights_qparams(weights_attr_values: np.ndarray,
4952
A dictionary with the quantization threshold of the kernel.
5053
Selected quantization channel axis.
5154
"""
52-
if attr_quant_config.weights_quantization_params_fn is not None:
53-
weights_params, output_channels_axis = attr_quant_config.weights_quantization_params_fn(
54-
weights_attr_values,
55-
p=attr_quant_config.l_p_value,
56-
n_bits=attr_quant_config.weights_n_bits,
57-
per_channel=attr_quant_config.weights_per_channel_threshold,
58-
channel_axis=output_channels_axis,
59-
min_threshold=weights_quant_config.min_threshold,
60-
quant_error_method=attr_quant_config.weights_error_method,
61-
node=node,
62-
hessian_info_service=hessian_info_service,
63-
num_hessian_samples=num_hessian_samples)
64-
else: # pragma: no cover
65-
Logger.error(f"Requested weights quantization parameters computation for node {node.name} without providing a "
66-
f"weights_quantization_params_fn."
67-
f"Returning an empty dictionary since no quantization parameters were computed.")
68-
weights_params = {}
55+
params_fn = _get_weights_quantization_params_fn(attr_quant_config.weights_quantization_method)
56+
weights_params, output_channels_axis = params_fn(
57+
weights_attr_values,
58+
p=attr_quant_config.l_p_value,
59+
n_bits=attr_quant_config.weights_n_bits,
60+
per_channel=attr_quant_config.weights_per_channel_threshold,
61+
channel_axis=output_channels_axis,
62+
min_threshold=min_threshold,
63+
quant_error_method=attr_quant_config.weights_error_method,
64+
node=node,
65+
hessian_info_service=hessian_info_service,
66+
num_hessian_samples=num_hessian_samples)
6967

7068
return weights_params, output_channels_axis
69+
70+
71+
_weights_quant_params_fns = {
72+
QuantizationMethod.POWER_OF_TWO: power_of_two_selection_tensor,
73+
QuantizationMethod.SYMMETRIC: symmetric_selection_tensor,
74+
QuantizationMethod.UNIFORM: uniform_selection_tensor,
75+
QuantizationMethod.LUT_POT_QUANTIZER: partial(lut_kmeans_tensor, is_symmetric=False),
76+
QuantizationMethod.LUT_SYM_QUANTIZER: partial(lut_kmeans_tensor, is_symmetric=True)
77+
}
78+
79+
80+
def _get_weights_quantization_params_fn(weights_quantization_method: QuantizationMethod) -> Callable:
81+
"""
82+
Generate a function for finding weights quantization parameters.
83+
84+
Args:
85+
weights_quantization_method: Which quantization method to use for weights.
86+
Returns:
87+
A function to find the quantization parameters.
88+
89+
"""
90+
params_fn = _weights_quant_params_fns.get(weights_quantization_method)
91+
if not params_fn:
92+
Logger.critical(
93+
f"No parameter function found for the specified quantization method: {weights_quantization_method}") # pragma: no cover
94+
return params_fn

0 commit comments

Comments (0)