1414# ==============================================================================
1515import pytest
1616
17- #import model_compression_toolkit as mct
1817from model_compression_toolkit .core .common .network_editors import NodeTypeFilter , NodeNameFilter
1918from model_compression_toolkit .core import CoreConfig
2019
2625from model_compression_toolkit .target_platform_capabilities .targetplatform2framework .attach2pytorch import \
2726 AttachTpcToPytorch
2827
28+ import model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema as schema
29+ from model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema import Signedness , \
30+ AttributeQuantizationConfig
31+ from mct_quantizers import QuantizationMethod
32+ from model_compression_toolkit .constants import FLOAT_BITWIDTH
33+
2934from model_compression_toolkit .core .pytorch .default_framework_info import DEFAULT_PYTORCH_INFO
3035from model_compression_toolkit .core .pytorch .pytorch_implementation import PytorchImplementation
3136
3540from model_compression_toolkit .target_platform_capabilities .constants import PYTORCH_KERNEL , BIAS
3641
3742class TestManualWeightsBitwidthSelection :
38- def get_tpc (self ):
39- base_cfg , _ , default_config = get_op_quantization_configs ()
43+ def get_op_qco (self ):
44+ # define a default quantization config for all non-specified weights attributes.
45+ default_weight_attr_config = AttributeQuantizationConfig (
46+ weights_quantization_method = QuantizationMethod .POWER_OF_TWO ,
47+ weights_n_bits = 8 ,
48+ weights_per_channel_threshold = False ,
49+ enable_weights_quantization = False ,
50+ # TODO: this will changed to True once implementing multi-attributes quantization
51+ lut_values_bitwidth = None )
52+
53+ # define a quantization config to quantize the kernel (for layers where there is a kernel attribute).
54+ kernel_base_config = AttributeQuantizationConfig (
55+ weights_quantization_method = QuantizationMethod .SYMMETRIC ,
56+ weights_n_bits = 8 ,
57+ weights_per_channel_threshold = True ,
58+ enable_weights_quantization = True ,
59+ lut_values_bitwidth = None )
60+
61+ # define a quantization config to quantize the bias (for layers where there is a bias attribute).
62+ bias_config = AttributeQuantizationConfig (
63+ weights_quantization_method = QuantizationMethod .POWER_OF_TWO ,
64+ weights_n_bits = FLOAT_BITWIDTH ,
65+ weights_per_channel_threshold = False ,
66+ enable_weights_quantization = False ,
67+ lut_values_bitwidth = None )
68+
69+ base_cfg = schema .OpQuantizationConfig (
70+ default_weight_attr_config = default_weight_attr_config ,
71+ attr_weights_configs_mapping = {KERNEL_ATTR : kernel_base_config , BIAS_ATTR : bias_config },
72+ activation_quantization_method = QuantizationMethod .POWER_OF_TWO ,
73+ activation_n_bits = 8 ,
74+ supported_input_activation_n_bits = 8 ,
75+ enable_activation_quantization = True ,
76+ quantization_preserving = False ,
77+ fixed_scale = None ,
78+ fixed_zero_point = None ,
79+ simd_size = 32 ,
80+ signedness = Signedness .AUTO )
81+
82+ default_config = schema .OpQuantizationConfig (
83+ default_weight_attr_config = default_weight_attr_config ,
84+ attr_weights_configs_mapping = {},
85+ activation_quantization_method = QuantizationMethod .POWER_OF_TWO ,
86+ activation_n_bits = 8 ,
87+ supported_input_activation_n_bits = 8 ,
88+ enable_activation_quantization = True ,
89+ quantization_preserving = False ,
90+ fixed_scale = None ,
91+ fixed_zero_point = None ,
92+ simd_size = 32 ,
93+ signedness = Signedness .AUTO )
4094
4195 mx_cfg_list = [base_cfg ]
4296 for n in [2 , 4 , 16 ]:
4397 mx_cfg_list .append (base_cfg .clone_and_edit (attr_to_edit = {KERNEL_ATTR : {WEIGHTS_N_BITS : n }}))
44- mx_cfg_list .append (base_cfg .clone_and_edit (attr_to_edit = {BIAS_ATTR : {WEIGHTS_N_BITS : n }}))
4598 mx_cfg_list .append (
46- base_cfg .clone_and_edit (attr_to_edit = {KERNEL_ATTR : {WEIGHTS_N_BITS : 4 }, BIAS_ATTR : { WEIGHTS_N_BITS : 16 } })
99+ base_cfg .clone_and_edit (attr_to_edit = {KERNEL_ATTR : {WEIGHTS_N_BITS : 4 }})
47100 )
101+
102+ return base_cfg , mx_cfg_list , default_config
103+
104+ def get_tpc (self ):
105+ base_cfg , mx_cfg_list , default_config = self .get_op_qco ()
106+
48107 tpc = generate_tpc (default_config = default_config , base_config = base_cfg , mixed_precision_cfg_list = mx_cfg_list ,
49- name = 'imx500_tpc_kai ' )
108+ name = 'test_set_node_quantization_config ' )
50109
51110 return tpc
52111
@@ -65,6 +124,7 @@ def __init__(self):
65124 def forward (self , x ):
66125 x = self .conv1 (x )
67126 x = self .conv2 (x )
127+ x = torch .add (x , 2 )
68128 x = self .relu (x )
69129 return x
70130 return BaseModel ()
@@ -90,16 +150,16 @@ def get_test_graph(self, core_config):
90150 """
91151 Test Items Policy:
92152 - How to specify the target layer: Options(type/name)
93- - Target attribute information: Options(kernel/bias )
153+ - Target attribute information: Options(kernel)
94154 - Bit width variations: Options(2, 4, 16)
95155 """
96156 test_input_1 = (NodeNameFilter ("conv1" ), 2 , PYTORCH_KERNEL )
97157 test_input_2 = (NodeTypeFilter (nn .Conv2d ), 16 , PYTORCH_KERNEL )
98- test_input_3 = ([NodeNameFilter ("conv1" ), NodeNameFilter ("conv1 " )], [4 , 16 ], [PYTORCH_KERNEL , BIAS ])
158+ test_input_3 = ([NodeNameFilter ("conv1" ), NodeNameFilter ("conv2 " )], [4 , 8 ], [PYTORCH_KERNEL , PYTORCH_KERNEL ])
99159
100- test_expected_1 = ({"conv1" : {PYTORCH_KERNEL : 2 , BIAS : 32 }, "conv2" : {PYTORCH_KERNEL : 8 , BIAS : 32 }})
101- test_expected_2 = ({"conv1" : {PYTORCH_KERNEL : 16 , BIAS : 32 }, "conv2" : {PYTORCH_KERNEL : 16 , BIAS : 32 }})
102- test_expected_3 = ({"conv1" : {PYTORCH_KERNEL : 4 , BIAS : 16 }, "conv2" : {PYTORCH_KERNEL : 8 , BIAS : 32 }})
160+ test_expected_1 = ({"conv1" : {PYTORCH_KERNEL : 2 }, "conv2" : {PYTORCH_KERNEL : 8 }})
161+ test_expected_2 = ({"conv1" : {PYTORCH_KERNEL : 16 }, "conv2" : {PYTORCH_KERNEL : 16 }})
162+ test_expected_3 = ({"conv1" : {PYTORCH_KERNEL : 4 }, "conv2" : {PYTORCH_KERNEL : 8 }})
103163
104164 @pytest .mark .parametrize (
105165 ("inputs" , "expected" ), [
@@ -123,6 +183,27 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected):
123183 if exp_vals is None : continue
124184 assert len (node .candidates_quantization_cfg ) == 1
125185
126- for vkey in node .candidates_quantization_cfg [0 ].weights_quantization_cfg .attributes_config_mapping :
127- cfg = node .candidates_quantization_cfg [0 ].weights_quantization_cfg .attributes_config_mapping [vkey ]
128- assert cfg .weights_n_bits == exp_vals [vkey ]
186+ cfg_list = node .candidates_quantization_cfg [0 ].weights_quantization_cfg .attributes_config_mapping
187+ for vkey in cfg_list :
188+ cfg = cfg_list .get (vkey )
189+ if exp_vals .get (vkey ) is not None :
190+ assert cfg .weights_n_bits == exp_vals .get (vkey )
191+
192+ test_input_4 = (NodeNameFilter ("add" ), 2 , PYTORCH_KERNEL )
193+ test_expected_4 = ('The requested attribute weight to change the bit width for add:add does not exist.' )
194+ @pytest .mark .parametrize (
195+ ("inputs" , "expected" ), [
196+ (test_input_4 , test_expected_4 ),
197+ ])
198+ def test_manual_weights_bitwidth_selection_error_add (self , inputs , expected ):
199+ core_config = CoreConfig ()
200+ graph = self .get_test_graph (core_config )
201+
202+ core_config .bit_width_config .set_manual_weights_bit_width (inputs [0 ], inputs [1 ], inputs [2 ])
203+ try :
204+ updated_graph = set_quantization_configuration_to_graph (
205+ graph , core_config .quantization_config , core_config .bit_width_config ,
206+ False , False
207+ )
208+ except Exception as e :
209+ assert expected == str (e )