Skip to content

Commit 011c876

Browse files
irenab
authored and committed
remove weights_quantization_fn from WeightsAttrQuantizationConfig
1 parent 6d740b2 commit 011c876

5 files changed

Lines changed: 29 additions & 27 deletions

File tree

model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717
import numpy as np
1818

19-
from model_compression_toolkit.core.common.framework_info import get_fw_info
2019
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
2120
CandidateNodeQuantizationConfig
22-
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantizer
21+
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import (get_activation_quantizer,
22+
get_weights_quantization_fn)
2323

2424

2525
def verify_candidates_descending_order(node_q_cfg: List[CandidateNodeQuantizationConfig],
@@ -79,13 +79,13 @@ def init_quantized_weights(node_q_cfg: List[CandidateNodeQuantizationConfig],
7979
quantized_weights = []
8080
for qc in node_q_cfg:
8181
qc_weights_attr = qc.weights_quantization_cfg.get_attr_config(kernel_attr)
82-
q_weight = qc_weights_attr.weights_quantization_fn(float_weights,
83-
qc_weights_attr.weights_n_bits,
84-
True,
85-
qc_weights_attr.weights_quantization_params,
86-
qc_weights_attr.weights_per_channel_threshold,
87-
qc_weights_attr.weights_channels_axis[
88-
0]) # output channel axis
82+
weights_quantization_fn = get_weights_quantization_fn(qc_weights_attr.weights_quantization_method)
83+
q_weight = weights_quantization_fn(float_weights,
84+
qc_weights_attr.weights_n_bits,
85+
True,
86+
qc_weights_attr.weights_quantization_params,
87+
qc_weights_attr.weights_per_channel_threshold,
88+
qc_weights_attr.weights_channels_axis[0]) # output channel axis
8989

9090
quantized_weights.append(fw_tensor_convert_func(q_weight))
9191

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,6 @@ def __init__(self,
226226
weights_channels_axis: Axis to quantize a node's attribute when quantizing per-channel (if not quantizing per-channel than expecting None).
227227
"""
228228
# TODO irena remove functions.
229-
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
230-
self.weights_quantization_fn = get_weights_quantization_fn(weights_attr_cfg.weights_quantization_method)
231229
self.weights_quantization_params_fn = get_weights_quantization_params_fn(weights_attr_cfg.weights_quantization_method)
232230
self.weights_channels_axis = weights_channels_axis
233231
self.weights_quantization_method = weights_attr_cfg.weights_quantization_method
@@ -316,8 +314,7 @@ def __eq__(self, other: Any) -> bool:
316314
if not isinstance(other, WeightsAttrQuantizationConfig):
317315
return False # pragma: no cover
318316

319-
return self.weights_quantization_fn == other.weights_quantization_fn and \
320-
self.weights_quantization_params_fn == other.weights_quantization_params_fn and \
317+
return self.weights_quantization_params_fn == other.weights_quantization_params_fn and \
321318
self.weights_channels_axis == other.weights_channels_axis and \
322319
self.weights_quantization_method == other.weights_quantization_method and \
323320
self.weights_n_bits == other.weights_n_bits and \
@@ -327,8 +324,7 @@ def __eq__(self, other: Any) -> bool:
327324
self.l_p_value == other.l_p_value
328325

329326
def __hash__(self):
330-
return hash((self.weights_quantization_fn,
331-
self.weights_quantization_params_fn,
327+
return hash((self.weights_quantization_params_fn,
332328
self.weights_channels_axis,
333329
self.weights_error_method,
334330
self.weights_quantization_method,

model_compression_toolkit/core/common/quantization/quantize_node.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
16-
15+
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
1716
from model_compression_toolkit.logger import Logger
1817
from model_compression_toolkit.core.common.graph.base_node import BaseNode
1918
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
@@ -47,11 +46,12 @@ def get_quantized_weights_attr_by_qc(attr_name: str,
4746
output_channels_axis = None
4847

4948
Logger.debug(f'quantizing layer {n.name} attribute {attr_name} with {weights_qc.weights_n_bits} bits')
50-
quantized_kernel = weights_qc.weights_quantization_fn(n.get_weights_by_keys(attr_name),
51-
n_bits=weights_qc.weights_n_bits,
52-
signed=True,
53-
quantization_params=weights_qc.weights_quantization_params,
54-
per_channel=weights_qc.weights_per_channel_threshold,
55-
output_channels_axis=output_channels_axis)
49+
weights_quantization_fn = get_weights_quantization_fn(weights_qc.weights_quantization_method)
50+
quantized_kernel = weights_quantization_fn(n.get_weights_by_keys(attr_name),
51+
n_bits=weights_qc.weights_n_bits,
52+
signed=True,
53+
quantization_params=weights_qc.weights_quantization_params,
54+
per_channel=weights_qc.weights_per_channel_threshold,
55+
output_channels_axis=output_channels_axis)
5656

5757
return quantized_kernel, channels_axis

tests_pytest/keras_tests/unit_tests/core/mixed_precision/test_set_quant_layer_to_bitwidth.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,12 @@ def quant_factory(nbits, *args, **kwargs):
4949
assert np.allclose(x*abits[ind], y)
5050

5151
@pytest.mark.parametrize('ind', [None, 0, 1, 2])
52-
def test_configure_weights(self, ind):
52+
def test_configure_weights(self, ind, mocker):
5353
""" Test correct weights quantizer is set and applied. """
54+
def quant_factory(*args, **kwargs):
55+
return lambda x, nbits, *args: x * nbits
56+
mocker.patch('model_compression_toolkit.core.common.mixed_precision.configurable_quantizer_utils.'
57+
'get_weights_quantization_fn', quant_factory)
5458
inp = keras.layers.Input(shape=(16, 16, 3))
5559
out = keras.layers.Conv2D(8, kernel_size=5)(inp)
5660
model = keras.Model(inp, out)
@@ -63,7 +67,6 @@ def test_configure_weights(self, ind):
6367
for qc in qcs:
6468
attr_cfg = qc.weights_quantization_cfg.get_attr_config(KERNEL)
6569
attr_cfg.weights_channels_axis = (0,)
66-
attr_cfg.weights_quantization_fn = lambda x, nbits, *args: x*nbits
6770
quantizer = ConfigurableWeightsQuantizer(
6871
node_q_cfg=qcs,
6972
float_weights=inner_layer.kernel.numpy(),

tests_pytest/pytorch_tests/unit_tests/core/mixed_precision/test_set_quant_layer_to_bitwidth.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@ def quant_factory(nbits, *args, **kwargs):
5252
assert torch.allclose(x*abits[ind], y)
5353

5454
@pytest.mark.parametrize('ind', [None, 0, 1, 2])
55-
def test_configure_weights(self, ind):
55+
def test_configure_weights(self, ind, mocker):
5656
""" Test correct weights quantizer is set and applied. """
57+
def quant_factory(*args, **kwargs):
58+
return lambda x, nbits, *args: x * nbits
59+
mocker.patch('model_compression_toolkit.core.common.mixed_precision.configurable_quantizer_utils.'
60+
'get_weights_quantization_fn', quant_factory)
5761
inner_layer = torch.nn.Conv2d(3, 8, kernel_size=5).to(get_working_device())
5862
orig_weight = inner_layer.weight.clone()
5963
orig_bias = inner_layer.bias.clone()
@@ -63,7 +67,6 @@ def test_configure_weights(self, ind):
6367
for qc in qcs:
6468
attr_cfg = qc.weights_quantization_cfg.get_attr_config(KERNEL)
6569
attr_cfg.weights_channels_axis = (0,)
66-
attr_cfg.weights_quantization_fn = lambda x, nbits, *args: x*nbits
6770

6871
quantizer = ConfigurableWeightsQuantizer(
6972
node_q_cfg=qcs,

0 commit comments

Comments (0)