Skip to content

Commit 549c6d9

Browse files
authored
Add tests for manual bit width config for positional weights (#1424)
* Add tests for manual bit width config for positional weights
1 parent 0dbb760 commit 549c6d9

File tree

3 files changed

+315
-45
lines changed

3 files changed

+315
-45
lines changed

model_compression_toolkit/core/common/quantization/bit_width_config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from model_compression_toolkit.logger import Logger
2121

2222
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
23+
from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR
24+
2325

2426
@dataclass
2527
class ManualBitWidthSelection:
@@ -221,9 +223,10 @@ def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict:
221223
if isinstance(attr_str, str) and isinstance(manual_bit_width_selection.attr, str):
222224
if attr_str.find(manual_bit_width_selection.attr) != -1:
223225
attr.append(attr_str)
224-
elif isinstance(attr_str, int) and isinstance(manual_bit_width_selection.attr, int):
225-
if attr_str == manual_bit_width_selection.attr:
226-
attr.append(attr_str)
226+
# this is a positional attribute, so it needs to be handled separately.
227+
# Search manual_bit_width_selection's attributes that contain the POS_ATTR string.
228+
elif isinstance(attr_str, int) and POS_ATTR in manual_bit_width_selection.attr:
229+
attr.append(POS_ATTR)
227230
if len(attr) == 0:
228231
Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')
229232

tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py

Lines changed: 182 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
import model_compression_toolkit as mct
1818
import torch
1919
from torch.nn import Conv2d
20-
from model_compression_toolkit.target_platform_capabilities.constants import BIAS, PYTORCH_KERNEL
20+
from torch import add, sub
21+
22+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
23+
from model_compression_toolkit.target_platform_capabilities.constants import BIAS, PYTORCH_KERNEL, POS_ATTR
2124
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS
2225
from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter
2326
from model_compression_toolkit.core import CoreConfig
@@ -94,52 +97,142 @@ def generate_tpc_local(default_config, base_config, mixed_precision_cfg_list):
9497
return generated_tpc
9598

9699

97-
def get_tpc(kernel_n_bits, bias_n_bits):
98-
base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
99-
tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list)
100-
return tpc
100+
def generate_tpc_pos_attr_local(default_config):
101+
default_configuration_options = schema.QuantizationConfigOptions(
102+
quantization_configurations=tuple([default_config]))
103+
104+
const_config_input16 = default_config.clone_and_edit(
105+
supported_input_activation_n_bits=(8, 16))
106+
const_config_input16_output16 = const_config_input16.clone_and_edit(
107+
activation_n_bits=16, signedness=schema.Signedness.SIGNED)
108+
109+
# define a quantization config to quantize the positional weights into 16 bit (for layers where there is a
110+
# positional weight attribute).
111+
positional_weight_16_attr_config = schema.AttributeQuantizationConfig(
112+
weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
113+
weights_n_bits=16,
114+
weights_per_channel_threshold=False,
115+
enable_weights_quantization=True,
116+
lut_values_bitwidth=None)
117+
118+
# define a quantization config to quantize the positional weights into 8 bit (for layers where there is a
119+
# positional weight attribute).
120+
positional_weight_8_attr_config = schema.AttributeQuantizationConfig(
121+
weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
122+
weights_n_bits=8,
123+
weights_per_channel_threshold=False,
124+
enable_weights_quantization=True,
125+
lut_values_bitwidth=None)
126+
127+
const_config_input16_positional_weight16 = const_config_input16.clone_and_edit(
128+
attr_weights_configs_mapping={POS_ATTR: positional_weight_16_attr_config})
129+
130+
const_config_input16_positional_weight8 = const_config_input16.clone_and_edit(
131+
attr_weights_configs_mapping={POS_ATTR: positional_weight_8_attr_config})
132+
const_configuration_options_inout16 = (
133+
schema.QuantizationConfigOptions(quantization_configurations=tuple([
134+
const_config_input16,
135+
const_config_input16_positional_weight8,
136+
const_config_input16_positional_weight16]),
137+
base_config=const_config_input16))
138+
139+
# define a quantization config to quantize the positional weights into 2 bit (for layers where there is a
140+
# positional weight attribute).
141+
positional_weight_2_attr_config = schema.AttributeQuantizationConfig(
142+
weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
143+
weights_n_bits=2,
144+
weights_per_channel_threshold=False,
145+
enable_weights_quantization=True,
146+
lut_values_bitwidth=None)
147+
148+
const_config_input16_positional_weight2 = const_config_input16.clone_and_edit(
149+
attr_weights_configs_mapping={POS_ATTR: positional_weight_2_attr_config})
150+
const_configuration_options_inout_2 = (
151+
schema.QuantizationConfigOptions(quantization_configurations=tuple([
152+
const_config_input16,
153+
const_config_input16_positional_weight2]),
154+
base_config=const_config_input16))
155+
156+
operator_set = []
157+
158+
add = schema.OperatorsSet(name=schema.OperatorSetNames.ADD, qc_options=const_configuration_options_inout16)
159+
sub = schema.OperatorsSet(name=schema.OperatorSetNames.SUB, qc_options=const_configuration_options_inout_2)
160+
operator_set.extend([add, sub])
161+
162+
generated_tpc = schema.TargetPlatformCapabilities(
163+
default_qco=default_configuration_options,
164+
operator_set=tuple(operator_set))
165+
166+
return generated_tpc
167+
101168

102169
def representative_data_gen(shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=1):
103170
for _ in range(num_iter):
104171
yield [torch.randn(batch_size, *shape)] * num_inputs
105172

173+
106174
def get_float_model():
107-
class BaseModel(torch.nn.Module):
108-
def __init__(self):
109-
super().__init__()
110-
self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
111-
self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
112-
self.relu = torch.nn.ReLU()
175+
class BaseModel(torch.nn.Module):
176+
def __init__(self):
177+
super().__init__()
178+
self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
179+
self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
180+
self.relu = torch.nn.ReLU()
181+
182+
def forward(self, x):
183+
x = self.conv1(x)
184+
x = self.conv2(x)
185+
x = self.relu(x)
186+
return x
187+
188+
return BaseModel()
189+
190+
191+
def get_float_model_with_constants():
192+
class BaseModel(torch.nn.Module):
193+
def __init__(self):
194+
super().__init__()
195+
a = torch.rand(8)
196+
b = torch.rand(8)
197+
self.a = to_torch_tensor(a)
198+
self.b = to_torch_tensor(b)
199+
200+
def forward(self, x):
201+
x = torch.add(x, self.a)
202+
x = torch.sub(self.b, x)
203+
return x
113204

114-
def forward(self, x):
115-
x = self.conv1(x)
116-
x = self.conv2(x)
117-
x = self.relu(x)
118-
return x
119-
return BaseModel()
205+
return BaseModel()
120206

121207

122208
class TestManualWeightsBitwidthSelectionByLayerType:
209+
def get_float_model(self):
210+
return get_float_model()
211+
212+
def get_tpc(self, kernel_n_bits, bias_n_bits):
213+
base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
214+
tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list)
215+
return tpc
216+
123217
# (LayerType, bit width, attribute, kernel_n_bits, bias_n_bits)
124218
test_input_1 = (NodeTypeFilter(Conv2d), 16, PYTORCH_KERNEL, 16, None)
125219
test_input_2 = (NodeTypeFilter(Conv2d), [2], [PYTORCH_KERNEL], 2, None)
126-
220+
127221
test_expected_1 = ([Conv2d], [16])
128222
test_expected_2 = ([Conv2d], [2])
129-
223+
130224
@pytest.mark.parametrize(("inputs", "expected"), [
131225
(test_input_1, test_expected_1),
132226
(test_input_2, test_expected_2),
133227
])
134-
135228
def test_manual_weights_bitwidth_selection(self, inputs, expected):
136-
float_model = get_float_model()
229+
float_model = self.get_float_model()
230+
231+
target_platform_cap = self.get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4])
137232

138-
target_platform_cap = get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4])
139-
140233
core_config = CoreConfig()
141234
core_config.bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2])
142-
235+
143236
quantized_model, _ = mct.ptq.pytorch_post_training_quantization(
144237
in_module=float_model,
145238
representative_data_gen=representative_data_gen,
@@ -157,12 +250,20 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected):
157250
attrs = [attrs]
158251

159252
for bitwidth, attr in zip(expected_bitwidths, attrs):
160-
253+
161254
if layer.weights_quantizers.get(attr) is not None:
162255
assert layer.weights_quantizers.get(attr).num_bits == bitwidth
163256

164257

165258
class TestManualWeightsBitwidthSelectionByLayerName:
259+
def get_float_model(self):
260+
return get_float_model()
261+
262+
def get_tpc(self, kernel_n_bits, bias_n_bits):
263+
base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
264+
tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list)
265+
return tpc
266+
166267
# (LayerName, bit width, attribute, kernel_n_bits, bias_n_bits)
167268
test_input_1 = (NodeNameFilter("conv1"), 16, PYTORCH_KERNEL, 16, None)
168269
test_input_2 = (NodeNameFilter("conv1"), [2], [PYTORCH_KERNEL], 2, None)
@@ -171,22 +272,21 @@ class TestManualWeightsBitwidthSelectionByLayerName:
171272
test_expected_1 = (["conv1"], [16])
172273
test_expected_2 = (["conv1"], [2])
173274
test_expected_3 = (["conv1", "conv1"], [4, 16])
174-
275+
175276
@pytest.mark.parametrize(("inputs", "expected"), [
176277
(test_input_1, test_expected_1),
177278
(test_input_2, test_expected_2),
178279
(test_input_3, test_expected_3),
179280
])
180-
181281
def test_manual_weights_bitwidth_selection(self, inputs, expected):
182282

183-
float_model = get_float_model()
283+
float_model = self.get_float_model()
284+
285+
target_platform_cap = self.get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4])
184286

185-
target_platform_cap = get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4])
186-
187287
core_config = CoreConfig()
188288
core_config.bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2])
189-
289+
190290
quantized_model, _ = mct.ptq.pytorch_post_training_quantization(
191291
in_module=float_model,
192292
representative_data_gen=representative_data_gen,
@@ -207,7 +307,54 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected):
207307
else:
208308
for attr in attrs:
209309
if layer.weights_quantizers.get(attr) is not None:
210-
if attr == PYTORCH_KERNEL:
211-
assert layer.weights_quantizers.get(attr).num_bits == kernel_weights_n_bits
212-
elif attr == BIAS:
213-
assert layer.weights_quantizers.get(attr).num_bits == bias_weights_n_bits
310+
if attr == PYTORCH_KERNEL:
311+
assert layer.weights_quantizers.get(attr).num_bits == kernel_weights_n_bits
312+
elif attr == BIAS:
313+
assert layer.weights_quantizers.get(attr).num_bits == bias_weights_n_bits
314+
315+
316+
class TestManualPositionalAttrWeightsBitwidthSelectionByLayerType(TestManualWeightsBitwidthSelectionByLayerType):
317+
def get_float_model(self):
318+
return get_float_model_with_constants()
319+
320+
def get_tpc(self, kernel_n_bits, bias_n_bits):
321+
_, _, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
322+
tpc = generate_tpc_pos_attr_local(default_config)
323+
return tpc
324+
325+
# (LayerType, bit width, attribute)
326+
test_input_1 = (NodeTypeFilter(add), 16, POS_ATTR, 8, 8)
327+
test_input_2 = (NodeTypeFilter(sub), [2], [POS_ATTR], 8, 8)
328+
329+
test_expected_1 = ([add], [16])
330+
test_expected_2 = ([sub], [2])
331+
332+
@pytest.mark.parametrize(("inputs", "expected"), [
333+
(test_input_1, test_expected_1),
334+
(test_input_2, test_expected_2),
335+
])
336+
def test_manual_weights_bitwidth_selection(self, inputs, expected):
337+
super().test_manual_weights_bitwidth_selection(inputs, expected)
338+
339+
class TestManualPositionalAttrWeightsBitwidthSelectionByLayerName(TestManualWeightsBitwidthSelectionByLayerName):
340+
def get_float_model(self):
341+
return get_float_model_with_constants()
342+
343+
def get_tpc(self, kernel_n_bits, bias_n_bits):
344+
_, _, default_config = get_op_qco(kernel_n_bits, bias_n_bits)
345+
tpc = generate_tpc_pos_attr_local(default_config)
346+
return tpc
347+
348+
# (LayerType, bit width, attribute)
349+
test_input_1 = (NodeNameFilter("add"), 8, POS_ATTR, 8, 8)
350+
test_input_2 = (NodeNameFilter("sub"), [2], [POS_ATTR], 8, 8)
351+
352+
test_expected_1 = (['add'], [16])
353+
test_expected_2 = (['sub'], [2])
354+
355+
@pytest.mark.parametrize(("inputs", "expected"), [
356+
(test_input_1, test_expected_1),
357+
(test_input_2, test_expected_2),
358+
])
359+
def test_manual_weights_bitwidth_selection(self, inputs, expected):
360+
super().test_manual_weights_bitwidth_selection(inputs, expected)

0 commit comments

Comments
 (0)