Skip to content

Commit f29cbd5

Browse files
committed
fixed PR comments.
1 parent 18df904 commit f29cbd5

File tree

1 file changed

+95
-14
lines changed

1 file changed

+95
-14
lines changed

tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# ==============================================================================
1515
import pytest
1616

17-
#import model_compression_toolkit as mct
1817
from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter
1918
from model_compression_toolkit.core import CoreConfig
2019

@@ -26,6 +25,12 @@
2625
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
2726
AttachTpcToPytorch
2827

28+
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
29+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness, \
30+
AttributeQuantizationConfig
31+
from mct_quantizers import QuantizationMethod
32+
from model_compression_toolkit.constants import FLOAT_BITWIDTH
33+
2934
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
3035
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
3136

@@ -35,18 +40,72 @@
3540
from model_compression_toolkit.target_platform_capabilities.constants import PYTORCH_KERNEL, BIAS
3641

3742
class TestManualWeightsBitwidthSelection:
38-
def get_tpc(self):
39-
base_cfg, _, default_config = get_op_quantization_configs()
43+
def get_op_qco(self):
44+
# define a default quantization config for all non-specified weights attributes.
45+
default_weight_attr_config = AttributeQuantizationConfig(
46+
weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
47+
weights_n_bits=8,
48+
weights_per_channel_threshold=False,
49+
enable_weights_quantization=False,
50+
# TODO: this will changed to True once implementing multi-attributes quantization
51+
lut_values_bitwidth=None)
52+
53+
# define a quantization config to quantize the kernel (for layers where there is a kernel attribute).
54+
kernel_base_config = AttributeQuantizationConfig(
55+
weights_quantization_method=QuantizationMethod.SYMMETRIC,
56+
weights_n_bits=8,
57+
weights_per_channel_threshold=True,
58+
enable_weights_quantization=True,
59+
lut_values_bitwidth=None)
60+
61+
# define a quantization config to quantize the bias (for layers where there is a bias attribute).
62+
bias_config = AttributeQuantizationConfig(
63+
weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
64+
weights_n_bits=FLOAT_BITWIDTH,
65+
weights_per_channel_threshold=False,
66+
enable_weights_quantization=False,
67+
lut_values_bitwidth=None)
68+
69+
base_cfg = schema.OpQuantizationConfig(
70+
default_weight_attr_config=default_weight_attr_config,
71+
attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config},
72+
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
73+
activation_n_bits=8,
74+
supported_input_activation_n_bits=8,
75+
enable_activation_quantization=True,
76+
quantization_preserving=False,
77+
fixed_scale=None,
78+
fixed_zero_point=None,
79+
simd_size=32,
80+
signedness=Signedness.AUTO)
81+
82+
default_config = schema.OpQuantizationConfig(
83+
default_weight_attr_config=default_weight_attr_config,
84+
attr_weights_configs_mapping={},
85+
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
86+
activation_n_bits=8,
87+
supported_input_activation_n_bits=8,
88+
enable_activation_quantization=True,
89+
quantization_preserving=False,
90+
fixed_scale=None,
91+
fixed_zero_point=None,
92+
simd_size=32,
93+
signedness=Signedness.AUTO)
4094

4195
mx_cfg_list = [base_cfg]
4296
for n in [2, 4, 16]:
4397
mx_cfg_list.append(base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: n}}))
44-
mx_cfg_list.append(base_cfg.clone_and_edit(attr_to_edit={BIAS_ATTR: {WEIGHTS_N_BITS: n}}))
4598
mx_cfg_list.append(
46-
base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 4}, BIAS_ATTR: {WEIGHTS_N_BITS: 16}})
99+
base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 4}})
47100
)
101+
102+
return base_cfg, mx_cfg_list, default_config
103+
104+
def get_tpc(self):
105+
base_cfg, mx_cfg_list, default_config = self.get_op_qco()
106+
48107
tpc = generate_tpc(default_config=default_config, base_config=base_cfg, mixed_precision_cfg_list=mx_cfg_list,
49-
name='imx500_tpc_kai')
108+
name='test_set_node_quantization_config')
50109

51110
return tpc
52111

@@ -65,6 +124,7 @@ def __init__(self):
65124
def forward(self, x):
66125
x = self.conv1(x)
67126
x = self.conv2(x)
127+
x = torch.add(x, 2)
68128
x = self.relu(x)
69129
return x
70130
return BaseModel()
@@ -90,16 +150,16 @@ def get_test_graph(self, core_config):
90150
"""
91151
Test Items Policy:
92152
- How to specify the target layer: Options(type/name)
93-
- Target attribute information: Options(kernel/bias)
153+
- Target attribute information: Options(kernel)
94154
- Bit width variations: Options(2, 4, 16)
95155
"""
96156
test_input_1 = (NodeNameFilter("conv1"), 2, PYTORCH_KERNEL)
97157
test_input_2 = (NodeTypeFilter(nn.Conv2d), 16, PYTORCH_KERNEL)
98-
test_input_3 = ([NodeNameFilter("conv1"), NodeNameFilter("conv1")], [4, 16], [PYTORCH_KERNEL, BIAS])
158+
test_input_3 = ([NodeNameFilter("conv1"), NodeNameFilter("conv2")], [4, 8], [PYTORCH_KERNEL, PYTORCH_KERNEL])
99159

100-
test_expected_1 = ({"conv1": {PYTORCH_KERNEL: 2, BIAS: 32}, "conv2": {PYTORCH_KERNEL: 8, BIAS: 32}})
101-
test_expected_2 = ({"conv1": {PYTORCH_KERNEL: 16, BIAS: 32}, "conv2": {PYTORCH_KERNEL: 16, BIAS: 32}})
102-
test_expected_3 = ({"conv1": {PYTORCH_KERNEL: 4, BIAS: 16}, "conv2": {PYTORCH_KERNEL: 8, BIAS: 32}})
160+
test_expected_1 = ({"conv1": {PYTORCH_KERNEL: 2}, "conv2": {PYTORCH_KERNEL: 8}})
161+
test_expected_2 = ({"conv1": {PYTORCH_KERNEL: 16}, "conv2": {PYTORCH_KERNEL: 16}})
162+
test_expected_3 = ({"conv1": {PYTORCH_KERNEL: 4}, "conv2": {PYTORCH_KERNEL: 8}})
103163

104164
@pytest.mark.parametrize(
105165
("inputs", "expected"), [
@@ -123,6 +183,27 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected):
123183
if exp_vals is None: continue
124184
assert len(node.candidates_quantization_cfg) == 1
125185

126-
for vkey in node.candidates_quantization_cfg[0].weights_quantization_cfg.attributes_config_mapping:
127-
cfg = node.candidates_quantization_cfg[0].weights_quantization_cfg.attributes_config_mapping[vkey]
128-
assert cfg.weights_n_bits == exp_vals[vkey]
186+
cfg_list = node.candidates_quantization_cfg[0].weights_quantization_cfg.attributes_config_mapping
187+
for vkey in cfg_list:
188+
cfg = cfg_list.get(vkey)
189+
if exp_vals.get(vkey) is not None:
190+
assert cfg.weights_n_bits == exp_vals.get(vkey)
191+
192+
test_input_4 = (NodeNameFilter("add"), 2, PYTORCH_KERNEL)
193+
test_expected_4 = ('The requested attribute weight to change the bit width for add:add does not exist.')
194+
@pytest.mark.parametrize(
195+
("inputs", "expected"), [
196+
(test_input_4, test_expected_4),
197+
])
198+
def test_manual_weights_bitwidth_selection_error_add(self, inputs, expected):
199+
core_config = CoreConfig()
200+
graph = self.get_test_graph(core_config)
201+
202+
core_config.bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2])
203+
try:
204+
updated_graph = set_quantization_configuration_to_graph(
205+
graph, core_config.quantization_config, core_config.bit_width_config,
206+
False, False
207+
)
208+
except Exception as e:
209+
assert expected == str(e)

0 commit comments

Comments
 (0)