1717import model_compression_toolkit as mct
1818import torch
1919from torch .nn import Conv2d
20- from model_compression_toolkit .target_platform_capabilities .constants import BIAS , PYTORCH_KERNEL
20+ from torch import add , sub
21+
22+ from model_compression_toolkit .core .pytorch .utils import to_torch_tensor
23+ from model_compression_toolkit .target_platform_capabilities .constants import BIAS , PYTORCH_KERNEL , POS_ATTR
2124from model_compression_toolkit .target_platform_capabilities .constants import KERNEL_ATTR , BIAS_ATTR , WEIGHTS_N_BITS
2225from model_compression_toolkit .core .common .network_editors import NodeTypeFilter , NodeNameFilter
2326from model_compression_toolkit .core import CoreConfig
@@ -94,52 +97,142 @@ def generate_tpc_local(default_config, base_config, mixed_precision_cfg_list):
9497 return generated_tpc
9598
9699
97- def get_tpc (kernel_n_bits , bias_n_bits ):
98- base_cfg , mx_cfg_list , default_config = get_op_qco (kernel_n_bits , bias_n_bits )
99- tpc = generate_tpc_local (default_config , base_cfg , mx_cfg_list )
100- return tpc
def generate_tpc_pos_attr_local(default_config):
    """Build a TargetPlatformCapabilities where add/sub quantize positional weights.

    ADD is given positional-weight (POS_ATTR) options of 8 and 16 bits; SUB is
    given a 2-bit option. Both operator sets accept 8/16-bit input activations,
    with the plain 16-bit-input config as the base config.

    Args:
        default_config: OpQuantizationConfig used as the template for all derived
            configs and as the TPC-wide default.

    Returns:
        schema.TargetPlatformCapabilities containing the ADD and SUB operator sets.
    """
    def _pos_weight_attr_cfg(n_bits):
        # Power-of-two, per-tensor attribute config for a positional weight at
        # the requested bit-width (factored out of three duplicated literals).
        return schema.AttributeQuantizationConfig(
            weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
            weights_n_bits=n_bits,
            weights_per_channel_threshold=False,
            enable_weights_quantization=True,
            lut_values_bitwidth=None)

    default_configuration_options = schema.QuantizationConfigOptions(
        quantization_configurations=(default_config,))

    # Base config: layers accept 8- or 16-bit activations on their inputs.
    # NOTE: an unused `const_config_input16_output16` (16-bit signed output)
    # existed here previously; it was never referenced and has been removed.
    const_config_input16 = default_config.clone_and_edit(
        supported_input_activation_n_bits=(8, 16))

    const_config_input16_positional_weight16 = const_config_input16.clone_and_edit(
        attr_weights_configs_mapping={POS_ATTR: _pos_weight_attr_cfg(16)})
    const_config_input16_positional_weight8 = const_config_input16.clone_and_edit(
        attr_weights_configs_mapping={POS_ATTR: _pos_weight_attr_cfg(8)})

    # Options for ADD: unquantized-positional, 8-bit and 16-bit positional weight.
    const_configuration_options_inout16 = schema.QuantizationConfigOptions(
        quantization_configurations=(const_config_input16,
                                     const_config_input16_positional_weight8,
                                     const_config_input16_positional_weight16),
        base_config=const_config_input16)

    const_config_input16_positional_weight2 = const_config_input16.clone_and_edit(
        attr_weights_configs_mapping={POS_ATTR: _pos_weight_attr_cfg(2)})

    # Options for SUB: unquantized-positional and 2-bit positional weight.
    const_configuration_options_inout_2 = schema.QuantizationConfigOptions(
        quantization_configurations=(const_config_input16,
                                     const_config_input16_positional_weight2),
        base_config=const_config_input16)

    add = schema.OperatorsSet(name=schema.OperatorSetNames.ADD,
                              qc_options=const_configuration_options_inout16)
    sub = schema.OperatorsSet(name=schema.OperatorSetNames.SUB,
                              qc_options=const_configuration_options_inout_2)

    generated_tpc = schema.TargetPlatformCapabilities(
        default_qco=default_configuration_options,
        operator_set=(add, sub))

    return generated_tpc
167+
101168
def representative_data_gen(shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=1):
    """Yield `num_iter` batches; each batch is a list of `num_inputs` references
    to one random tensor of shape (batch_size, *shape)."""
    remaining = num_iter
    while remaining > 0:
        sample = torch.randn((batch_size,) + tuple(shape))
        yield [sample] * num_inputs
        remaining -= 1
105172
173+
def get_float_model():
    """Return a small float Conv -> Conv -> ReLU network used by the bit-width tests."""
    class BaseModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Two 3x3 convolutions followed by a ReLU. The submodule names
            # ('conv1', 'conv2') matter: tests filter nodes by these names.
            self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
            self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv2(self.conv1(x)))

    return BaseModel()
def get_float_model_with_constants():
    """Return a float model whose add/sub ops consume constant (positional) operands."""
    class BaseModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Constants attached as plain tensors (not Parameters/buffers) so they
            # show up as positional weight attributes on the add/sub nodes.
            self.a = to_torch_tensor(torch.rand(8))
            self.b = to_torch_tensor(torch.rand(8))

        def forward(self, x):
            return torch.sub(self.b, torch.add(x, self.a))

    return BaseModel()
122208class TestManualWeightsBitwidthSelectionByLayerType :
    def get_float_model(self):
        """Return the float model under test; hook for subclasses to override."""
        return get_float_model()
211+
212+ def get_tpc (self , kernel_n_bits , bias_n_bits ):
213+ base_cfg , mx_cfg_list , default_config = get_op_qco (kernel_n_bits , bias_n_bits )
214+ tpc = generate_tpc_local (default_config , base_cfg , mx_cfg_list )
215+ return tpc
216+
123217 # (LayerType, bit width, attribute, kernel_n_bits, bias_n_bits)
124218 test_input_1 = (NodeTypeFilter (Conv2d ), 16 , PYTORCH_KERNEL , 16 , None )
125219 test_input_2 = (NodeTypeFilter (Conv2d ), [2 ], [PYTORCH_KERNEL ], 2 , None )
126-
220+
127221 test_expected_1 = ([Conv2d ], [16 ])
128222 test_expected_2 = ([Conv2d ], [2 ])
129-
223+
130224 @pytest .mark .parametrize (("inputs" , "expected" ), [
131225 (test_input_1 , test_expected_1 ),
132226 (test_input_2 , test_expected_2 ),
133227 ])
134-
135228 def test_manual_weights_bitwidth_selection (self , inputs , expected ):
136- float_model = get_float_model ()
229+ float_model = self .get_float_model ()
230+
231+ target_platform_cap = self .get_tpc (kernel_n_bits = inputs [3 ], bias_n_bits = inputs [4 ])
137232
138- target_platform_cap = get_tpc (kernel_n_bits = inputs [3 ], bias_n_bits = inputs [4 ])
139-
140233 core_config = CoreConfig ()
141234 core_config .bit_width_config .set_manual_weights_bit_width (inputs [0 ], inputs [1 ], inputs [2 ])
142-
235+
143236 quantized_model , _ = mct .ptq .pytorch_post_training_quantization (
144237 in_module = float_model ,
145238 representative_data_gen = representative_data_gen ,
@@ -157,12 +250,20 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected):
157250 attrs = [attrs ]
158251
159252 for bitwidth , attr in zip (expected_bitwidths , attrs ):
160-
253+
161254 if layer .weights_quantizers .get (attr ) is not None :
162255 assert layer .weights_quantizers .get (attr ).num_bits == bitwidth
163256
164257
165258class TestManualWeightsBitwidthSelectionByLayerName :
    def get_float_model(self):
        """Return the float model under test; hook for subclasses to override."""
        return get_float_model()
261+
262+ def get_tpc (self , kernel_n_bits , bias_n_bits ):
263+ base_cfg , mx_cfg_list , default_config = get_op_qco (kernel_n_bits , bias_n_bits )
264+ tpc = generate_tpc_local (default_config , base_cfg , mx_cfg_list )
265+ return tpc
266+
166267 # (LayerName, bit width, attribute, kernel_n_bits, bias_n_bits)
167268 test_input_1 = (NodeNameFilter ("conv1" ), 16 , PYTORCH_KERNEL , 16 , None )
168269 test_input_2 = (NodeNameFilter ("conv1" ), [2 ], [PYTORCH_KERNEL ], 2 , None )
@@ -171,22 +272,21 @@ class TestManualWeightsBitwidthSelectionByLayerName:
171272 test_expected_1 = (["conv1" ], [16 ])
172273 test_expected_2 = (["conv1" ], [2 ])
173274 test_expected_3 = (["conv1" , "conv1" ], [4 , 16 ])
174-
275+
175276 @pytest .mark .parametrize (("inputs" , "expected" ), [
176277 (test_input_1 , test_expected_1 ),
177278 (test_input_2 , test_expected_2 ),
178279 (test_input_3 , test_expected_3 ),
179280 ])
180-
181281 def test_manual_weights_bitwidth_selection (self , inputs , expected ):
182282
183- float_model = get_float_model ()
283+ float_model = self .get_float_model ()
284+
285+ target_platform_cap = self .get_tpc (kernel_n_bits = inputs [3 ], bias_n_bits = inputs [4 ])
184286
185- target_platform_cap = get_tpc (kernel_n_bits = inputs [3 ], bias_n_bits = inputs [4 ])
186-
187287 core_config = CoreConfig ()
188288 core_config .bit_width_config .set_manual_weights_bit_width (inputs [0 ], inputs [1 ], inputs [2 ])
189-
289+
190290 quantized_model , _ = mct .ptq .pytorch_post_training_quantization (
191291 in_module = float_model ,
192292 representative_data_gen = representative_data_gen ,
@@ -207,7 +307,54 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected):
207307 else :
208308 for attr in attrs :
209309 if layer .weights_quantizers .get (attr ) is not None :
210- if attr == PYTORCH_KERNEL :
211- assert layer .weights_quantizers .get (attr ).num_bits == kernel_weights_n_bits
212- elif attr == BIAS :
213- assert layer .weights_quantizers .get (attr ).num_bits == bias_weights_n_bits
310+ if attr == PYTORCH_KERNEL :
311+ assert layer .weights_quantizers .get (attr ).num_bits == kernel_weights_n_bits
312+ elif attr == BIAS :
313+ assert layer .weights_quantizers .get (attr ).num_bits == bias_weights_n_bits
314+
315+
class TestManualPositionalAttrWeightsBitwidthSelectionByLayerType(TestManualWeightsBitwidthSelectionByLayerType):
    """Manual bit-width selection for positional weight attributes (POS_ATTR),
    filtering layers by type (torch.add / torch.sub)."""

    def get_float_model(self):
        # Override: model whose add/sub nodes carry constant (positional) operands.
        return get_float_model_with_constants()

    def get_tpc(self, kernel_n_bits, bias_n_bits):
        # Only the default config feeds the positional-attribute TPC; the
        # kernel/bias mixed-precision configs are discarded here.
        _, _, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
        tpc = generate_tpc_pos_attr_local(default_config)
        return tpc

    # (LayerType, bit width, attribute, kernel_n_bits, bias_n_bits)
    test_input_1 = (NodeTypeFilter(add), 16, POS_ATTR, 8, 8)
    test_input_2 = (NodeTypeFilter(sub), [2], [POS_ATTR], 8, 8)

    test_expected_1 = ([add], [16])
    test_expected_2 = ([sub], [2])

    @pytest.mark.parametrize(("inputs", "expected"), [
        (test_input_1, test_expected_1),
        (test_input_2, test_expected_2),
    ])
    def test_manual_weights_bitwidth_selection(self, inputs, expected):
        # Override exists only to re-parametrize; assertions live in the base class.
        super().test_manual_weights_bitwidth_selection(inputs, expected)
338+
class TestManualPositionalAttrWeightsBitwidthSelectionByLayerName(TestManualWeightsBitwidthSelectionByLayerName):
    """Manual bit-width selection for positional weight attributes (POS_ATTR),
    filtering layers by node name ('add' / 'sub')."""

    def get_float_model(self):
        # Override: model whose add/sub nodes carry constant (positional) operands.
        return get_float_model_with_constants()

    def get_tpc(self, kernel_n_bits, bias_n_bits):
        # Only the default config feeds the positional-attribute TPC; the
        # kernel/bias mixed-precision configs are discarded here.
        _, _, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
        return generate_tpc_pos_attr_local(default_config)

    # (LayerName, bit width, attribute, kernel_n_bits, bias_n_bits)
    # Fix: the manual bit-width for 'add' was 8 while test_expected_1 expects 16;
    # aligned to 16 to match the expectation and the ByLayerType sibling test.
    test_input_1 = (NodeNameFilter("add"), 16, POS_ATTR, 8, 8)
    test_input_2 = (NodeNameFilter("sub"), [2], [POS_ATTR], 8, 8)

    test_expected_1 = (['add'], [16])
    test_expected_2 = (['sub'], [2])

    @pytest.mark.parametrize(("inputs", "expected"), [
        (test_input_1, test_expected_1),
        (test_input_2, test_expected_2),
    ])
    def test_manual_weights_bitwidth_selection(self, inputs, expected):
        # Override exists only to re-parametrize; assertions live in the base class.
        super().test_manual_weights_bitwidth_selection(inputs, expected)
0 commit comments