diff --git a/model_compression_toolkit/core/common/quantization/bit_width_config.py b/model_compression_toolkit/core/common/quantization/bit_width_config.py index 77af61ce1..abd4cc628 100644 --- a/model_compression_toolkit/core/common/quantization/bit_width_config.py +++ b/model_compression_toolkit/core/common/quantization/bit_width_config.py @@ -20,6 +20,8 @@ from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.graph.base_node import WeightAttrT +from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR + @dataclass class ManualBitWidthSelection: @@ -221,9 +223,10 @@ def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict: if isinstance(attr_str, str) and isinstance(manual_bit_width_selection.attr, str): if attr_str.find(manual_bit_width_selection.attr) != -1: attr.append(attr_str) - elif isinstance(attr_str, int) and isinstance(manual_bit_width_selection.attr, int): - if attr_str == manual_bit_width_selection.attr: - attr.append(attr_str) + # this is a positional attribute, so it needs to be handled separately. + # Search manual_bit_width_selection's attribute that contain the POS_ATTR string. + elif isinstance(attr_str, int) and POS_ATTR in manual_bit_width_selection.attr: + attr.append(POS_ATTR) if len(attr) == 0: Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.') diff --git a/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py b/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py index c49bef96c..7ca3fdaea 100755 --- a/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py +++ b/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py @@ -17,7 +17,10 @@ import model_compression_toolkit as mct import torch from torch.nn import Conv2d -from model_compression_toolkit.target_platform_capabilities.constants import BIAS, PYTORCH_KERNEL +from torch import add, sub + +from model_compression_toolkit.core.pytorch.utils import to_torch_tensor +from model_compression_toolkit.target_platform_capabilities.constants import BIAS, PYTORCH_KERNEL, POS_ATTR from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter from model_compression_toolkit.core import CoreConfig @@ -94,52 +97,142 @@ def generate_tpc_local(default_config, base_config, mixed_precision_cfg_list): return generated_tpc -def get_tpc(kernel_n_bits, bias_n_bits): - base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits) - tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list) - return tpc +def generate_tpc_pos_attr_local(default_config): + default_configuration_options = schema.QuantizationConfigOptions( + quantization_configurations=tuple([default_config])) + + const_config_input16 = default_config.clone_and_edit( + supported_input_activation_n_bits=(8, 16)) + const_config_input16_output16 = const_config_input16.clone_and_edit( + activation_n_bits=16, signedness=schema.Signedness.SIGNED) + + # define a quantization config to quantize the positional weights into 16 bit (for layers where there is a + # positional weight attribute). + positional_weight_16_attr_config = schema.AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=16, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + # define a quantization config to quantize the positional weights into 8 bit (for layers where there is a + # positional weight attribute). + positional_weight_8_attr_config = schema.AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=8, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + const_config_input16_positional_weight16 = const_config_input16.clone_and_edit( + attr_weights_configs_mapping={POS_ATTR: positional_weight_16_attr_config}) + + const_config_input16_positional_weight8 = const_config_input16.clone_and_edit( + attr_weights_configs_mapping={POS_ATTR: positional_weight_8_attr_config}) + const_configuration_options_inout16 = ( + schema.QuantizationConfigOptions(quantization_configurations=tuple([ + const_config_input16, + const_config_input16_positional_weight8, + const_config_input16_positional_weight16]), + base_config=const_config_input16)) + + # define a quantization config to quantize the positional weights into 2 bit (for layers where there is a + # positional weight attribute). + positional_weight_2_attr_config = schema.AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=2, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + const_config_input16_positional_weight2 = const_config_input16.clone_and_edit( + attr_weights_configs_mapping={POS_ATTR: positional_weight_2_attr_config}) + const_configuration_options_inout_2 = ( + schema.QuantizationConfigOptions(quantization_configurations=tuple([ + const_config_input16, + const_config_input16_positional_weight2]), + base_config=const_config_input16)) + + operator_set = [] + + add = schema.OperatorsSet(name=schema.OperatorSetNames.ADD, qc_options=const_configuration_options_inout16) + sub = schema.OperatorsSet(name=schema.OperatorSetNames.SUB, qc_options=const_configuration_options_inout_2) + operator_set.extend([add, sub]) + + generated_tpc = schema.TargetPlatformCapabilities( + default_qco=default_configuration_options, + operator_set=tuple(operator_set)) + + return generated_tpc + def representative_data_gen(shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=1): for _ in range(num_iter): yield [torch.randn(batch_size, *shape)] * num_inputs + def get_float_model(): - class BaseModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) - self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) - self.relu = torch.nn.ReLU() + class BaseModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) + self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.relu(x) + return x + + return BaseModel() + + +def get_float_model_with_constants(): + class BaseModel(torch.nn.Module): + def __init__(self): + super().__init__() + a = torch.rand(8) + b = torch.rand(8) + self.a = to_torch_tensor(a) + self.b = to_torch_tensor(b) + + def forward(self, x): + x = torch.add(x, self.a) + x = torch.sub(self.b, x) + return x - def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - x = self.relu(x) - return x - return BaseModel() + return BaseModel() class TestManualWeightsBitwidthSelectionByLayerType: + def get_float_model(self): + return get_float_model() + + def get_tpc(self, kernel_n_bits, bias_n_bits): + base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits) + tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list) + return tpc + # (LayerType, bit width, attribute, kernel_n_bits, bias_n_bits) test_input_1 = (NodeTypeFilter(Conv2d), 16, PYTORCH_KERNEL, 16, None) test_input_2 = (NodeTypeFilter(Conv2d), [2], [PYTORCH_KERNEL], 2, None) - + test_expected_1 = ([Conv2d], [16]) test_expected_2 = ([Conv2d], [2]) - + @pytest.mark.parametrize(("inputs", "expected"), [ (test_input_1, test_expected_1), (test_input_2, test_expected_2), ]) - def test_manual_weights_bitwidth_selection(self, inputs, expected): - float_model = get_float_model() + float_model = self.get_float_model() + + target_platform_cap = self.get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4]) - target_platform_cap = get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4]) - core_config = CoreConfig() core_config.bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2]) - + quantized_model, _ = mct.ptq.pytorch_post_training_quantization( in_module=float_model, representative_data_gen=representative_data_gen, @@ -157,12 +250,20 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected): attrs = [attrs] for bitwidth, attr in zip(expected_bitwidths, attrs): - + if layer.weights_quantizers.get(attr) is not None: assert layer.weights_quantizers.get(attr).num_bits == bitwidth class TestManualWeightsBitwidthSelectionByLayerName: + def get_float_model(self): + return get_float_model() + + def get_tpc(self, kernel_n_bits, bias_n_bits): + base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits) + tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list) + return tpc + # (LayerName, bit width, attribute, kernel_n_bits, bias_n_bits) test_input_1 = (NodeNameFilter("conv1"), 16, PYTORCH_KERNEL, 16, None) test_input_2 = (NodeNameFilter("conv1"), [2], [PYTORCH_KERNEL], 2, None) @@ -171,22 +272,21 @@ class TestManualWeightsBitwidthSelectionByLayerName: test_expected_1 = (["conv1"], [16]) test_expected_2 = (["conv1"], [2]) test_expected_3 = (["conv1", "conv1"], [4, 16]) - + @pytest.mark.parametrize(("inputs", "expected"), [ (test_input_1, test_expected_1), (test_input_2, test_expected_2), (test_input_3, test_expected_3), ]) - def test_manual_weights_bitwidth_selection(self, inputs, expected): - float_model = get_float_model() + float_model = self.get_float_model() + + target_platform_cap = self.get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4]) - target_platform_cap = get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4]) - core_config = CoreConfig() core_config.bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2]) - + quantized_model, _ = mct.ptq.pytorch_post_training_quantization( in_module=float_model, representative_data_gen=representative_data_gen, @@ -207,7 +307,54 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected): else: for attr in attrs: if layer.weights_quantizers.get(attr) is not None: - if attr == PYTORCH_KERNEL: - assert layer.weights_quantizers.get(attr).num_bits == kernel_weights_n_bits - elif attr == BIAS: - assert layer.weights_quantizers.get(attr).num_bits == bias_weights_n_bits + if attr == PYTORCH_KERNEL: + assert layer.weights_quantizers.get(attr).num_bits == kernel_weights_n_bits + elif attr == BIAS: + assert layer.weights_quantizers.get(attr).num_bits == bias_weights_n_bits + + +class TestManualPositionalAttrWeightsBitwidthSelectionByLayerType(TestManualWeightsBitwidthSelectionByLayerType): + def get_float_model(self): + return get_float_model_with_constants() + + def get_tpc(self, kernel_n_bits, bias_n_bits): + _, _, default_config = get_op_qco(kernel_n_bits, bias_n_bits) + tpc = generate_tpc_pos_attr_local(default_config) + return tpc + + # (LayerType, bit width, attribute) + test_input_1 = (NodeTypeFilter(add), 16, POS_ATTR, 8, 8) + test_input_2 = (NodeTypeFilter(sub), [2], [POS_ATTR], 8, 8) + + test_expected_1 = ([add], [16]) + test_expected_2 = ([sub], [2]) + + @pytest.mark.parametrize(("inputs", "expected"), [ + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + ]) + def test_manual_weights_bitwidth_selection(self, inputs, expected): + super().test_manual_weights_bitwidth_selection(inputs, expected) + +class TestManualPositionalAttrWeightsBitwidthSelectionByLayerName(TestManualWeightsBitwidthSelectionByLayerName): + def get_float_model(self): + return get_float_model_with_constants() + + def get_tpc(self, kernel_n_bits, bias_n_bits): + _, _, default_config = get_op_qco(kernel_n_bits, bias_n_bits) + tpc = generate_tpc_pos_attr_local(default_config) + return tpc + + # (LayerType, bit width, attribute) + test_input_1 = (NodeNameFilter("add"), 8, POS_ATTR, 8, 8) + test_input_2 = (NodeNameFilter("sub"), [2], [POS_ATTR], 8, 8) + + test_expected_1 = (['add'], [16]) + test_expected_2 = (['sub'], [2]) + + @pytest.mark.parametrize(("inputs", "expected"), [ + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + ]) + def test_manual_weights_bitwidth_selection(self, inputs, expected): + super().test_manual_weights_bitwidth_selection(inputs, expected) diff --git a/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py b/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py index 0e58915ec..d9a5c489f 100755 --- a/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py +++ b/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py @@ -17,12 +17,13 @@ from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter from model_compression_toolkit.core import BitWidthConfig, QuantizationConfig - from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \ set_quantization_configuration_to_graph import torch from torch import nn + +from model_compression_toolkit.core.pytorch.utils import to_torch_tensor from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \ AttachTpcToPytorch @@ -34,9 +35,10 @@ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation -from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, WEIGHTS_N_BITS +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, WEIGHTS_N_BITS, POS_ATTR from model_compression_toolkit.target_platform_capabilities.constants import PYTORCH_KERNEL + class TestManualWeightsBitwidthSelection: def get_op_qco(self): # define a default quantization config for all non-specified weights attributes. @@ -118,6 +120,7 @@ def forward(self, x): x = torch.add(x, 2) x = self.relu(x) return x + return BaseModel() def get_test_graph(self, qc): @@ -158,10 +161,10 @@ def get_test_graph(self, qc): @pytest.mark.parametrize( ("inputs", "expected"), [ - (test_input_1, test_expected_1), - (test_input_2, test_expected_2), - (test_input_3, test_expected_3), - ]) + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + (test_input_3, test_expected_3), + ]) def test_manual_weights_bitwidth_selection(self, inputs, expected): for mx_enable in [False, True]: bit_width_config = BitWidthConfig() @@ -197,7 +200,7 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected): ("inputs", "expected"), [ (test_input_4, test_expected_4), (test_input_5, test_expected_5), - ]) + ]) def test_manual_weights_bitwidth_selection_error_add(self, inputs, expected): for mx_enable in [False, True]: bit_width_config = BitWidthConfig() @@ -212,3 +215,120 @@ def test_manual_weights_bitwidth_selection_error_add(self, inputs, expected): ) except Exception as e: assert expected == str(e) + + +class TestManualPositionalAttrWeightsBitwidthSelection(TestManualWeightsBitwidthSelection): + def generate_tpc_local(self, default_config, base_config, mixed_precision_cfg_list): + default_configuration_options = schema.QuantizationConfigOptions( + quantization_configurations=tuple([default_config])) + + const_config_input16 = default_config.clone_and_edit( + supported_input_activation_n_bits=(8, 16)) + + # define a quantization config to quantize the positional weights into 16 bit (for layers where there is a + # positional weight attribute). + positional_weight_16_attr_config = schema.AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=16, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + # define a quantization config to quantize the positional weights into 8 bit (for layers where there is a + # positional weight attribute). + positional_weight_8_attr_config = schema.AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=8, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + const_config_input16_positional_weight16 = const_config_input16.clone_and_edit( + attr_weights_configs_mapping={POS_ATTR: positional_weight_16_attr_config}) + + const_config_input16_positional_weight8 = const_config_input16.clone_and_edit( + attr_weights_configs_mapping={POS_ATTR: positional_weight_8_attr_config}) + const_configuration_options_inout16 = ( + schema.QuantizationConfigOptions(quantization_configurations=tuple([ + const_config_input16, + const_config_input16_positional_weight8, + const_config_input16_positional_weight16]), + base_config=const_config_input16)) + + operator_set = [] + + add = schema.OperatorsSet(name=schema.OperatorSetNames.ADD, qc_options=const_configuration_options_inout16) + operator_set.extend([add]) + + generated_tpc = schema.TargetPlatformCapabilities( + default_qco=default_configuration_options, + operator_set=tuple(operator_set)) + + return generated_tpc + + def get_float_model(self): + class BaseModel(torch.nn.Module): + def __init__(self): + super().__init__() + const = torch.rand(8) + self.a = to_torch_tensor(const) + + def forward(self, x): + x = torch.add(x, self.a) + return x + + return BaseModel() + + # test case for set_manual_activation_bit_width + """ + Test Items Policy: + - How to specify the target layer: Options(type/name) + - Target attribute information: Options(kernel) + - Bit width variations: Options(2, 4, 16) + """ + test_input_1 = (NodeNameFilter("add"), 16, POS_ATTR) + test_input_2 = (NodeTypeFilter(torch.add), 8, POS_ATTR) + test_input_3 = (NodeNameFilter("add"), 4, POS_ATTR) + test_input_4 = (NodeNameFilter("add"), 2, POS_ATTR) + + test_expected_1 = ({"add": {1: 16}}) + test_expected_2 = ({"add": {1: 8}}) + test_expected_3 = ("Manually selected weights bit-width [[4, 'pos_attr']] is invalid for node " + 'add:add.') + test_expected_4 = ("Manually selected weights bit-width [[2, 'pos_attr']] is invalid for node " + 'add:add.') + + @pytest.mark.parametrize( + ("inputs", "expected"), [ + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + ]) + def test_manual_weights_bitwidth_selection(self, inputs, expected): + bit_width_config = BitWidthConfig() + quantization_config = QuantizationConfig() + graph = self.get_test_graph(quantization_config) + bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2]) + + updated_graph = set_quantization_configuration_to_graph( + graph, quantization_config, bit_width_config + ) + + for node in updated_graph.nodes: + exp_vals = expected.get(node.name) + assert len(node.candidates_quantization_cfg) == 1 + if exp_vals is None: + continue + + cfg_list = node.candidates_quantization_cfg[0].weights_quantization_cfg.pos_attributes_config_mapping + for vkey in cfg_list: + cfg = cfg_list.get(vkey) + if exp_vals.get(vkey) is not None: + assert cfg.weights_n_bits == exp_vals.get(vkey) + + @pytest.mark.parametrize( + ("inputs", "expected"), [ + (test_input_3, test_expected_3), + (test_input_4, test_expected_4), + ]) + def test_manual_weights_bitwidth_selection_error_add(self, inputs, expected): + super().test_manual_weights_bitwidth_selection_error_add(inputs, expected)