From 241e9820961dfe4337a1fdd97358afd57846d06e Mon Sep 17 00:00:00 2001 From: ariell Date: Wed, 23 Apr 2025 11:15:54 +0300 Subject: [PATCH 1/3] Add tests for manual bit width config for positional weights --- .../common/quantization/bit_width_config.py | 9 +- .../test_weights_manual_selection_bitwidth.py | 218 +++++++++++++++--- .../test_set_node_quantization_config.py | 137 ++++++++++- 3 files changed, 319 insertions(+), 45 deletions(-) diff --git a/model_compression_toolkit/core/common/quantization/bit_width_config.py b/model_compression_toolkit/core/common/quantization/bit_width_config.py index 77af61ce1..abd4cc628 100644 --- a/model_compression_toolkit/core/common/quantization/bit_width_config.py +++ b/model_compression_toolkit/core/common/quantization/bit_width_config.py @@ -20,6 +20,8 @@ from model_compression_toolkit.logger import Logger from model_compression_toolkit.core.common.graph.base_node import WeightAttrT +from model_compression_toolkit.target_platform_capabilities.constants import POS_ATTR + @dataclass class ManualBitWidthSelection: @@ -221,9 +223,10 @@ def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict: if isinstance(attr_str, str) and isinstance(manual_bit_width_selection.attr, str): if attr_str.find(manual_bit_width_selection.attr) != -1: attr.append(attr_str) - elif isinstance(attr_str, int) and isinstance(manual_bit_width_selection.attr, int): - if attr_str == manual_bit_width_selection.attr: - attr.append(attr_str) + # This is a positional attribute, so it needs to be handled separately. + # Search manual_bit_width_selection's attribute that contains the POS_ATTR string. 
+ elif isinstance(attr_str, int) and POS_ATTR in manual_bit_width_selection.attr: + attr.append(POS_ATTR) if len(attr) == 0: Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.') diff --git a/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py b/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py index c49bef96c..cd8008c14 100755 --- a/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py +++ b/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py @@ -17,7 +17,10 @@ import model_compression_toolkit as mct import torch from torch.nn import Conv2d -from model_compression_toolkit.target_platform_capabilities.constants import BIAS, PYTORCH_KERNEL +from torch import add, sub + +from model_compression_toolkit.core.pytorch.utils import to_torch_tensor +from model_compression_toolkit.target_platform_capabilities.constants import BIAS, PYTORCH_KERNEL, POS_ATTR from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter from model_compression_toolkit.core import CoreConfig @@ -94,52 +97,143 @@ def generate_tpc_local(default_config, base_config, mixed_precision_cfg_list): return generated_tpc -def get_tpc(kernel_n_bits, bias_n_bits): - base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits) - tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list) - return tpc +def generate_tpc_pos_attr_local(default_config): + default_configuration_options = schema.QuantizationConfigOptions( + quantization_configurations=tuple([default_config])) + + const_config_input16 = default_config.clone_and_edit( + supported_input_activation_n_bits=(8, 16)) + const_config_input16_output16 = const_config_input16.clone_and_edit( + 
activation_n_bits=16, signedness=schema.Signedness.SIGNED) + + # define a quantization config to quantize the positional weights into 16 bit (for layers where there is a + # positional weight attribute). + positional_weight_16_attr_config = schema.AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=16, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + # define a quantization config to quantize the positional weights into 8 bit (for layers where there is a + # positional weight attribute). + positional_weight_8_attr_config = schema.AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=8, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + const_config_input16_positional_weight16 = const_config_input16.clone_and_edit( + attr_weights_configs_mapping={POS_ATTR: positional_weight_16_attr_config}) + + const_config_input16_output16_positional_weight8 = const_config_input16_output16.clone_and_edit( + attr_weights_configs_mapping={POS_ATTR: positional_weight_8_attr_config}) + const_configuration_options_inout16 = ( + schema.QuantizationConfigOptions(quantization_configurations=tuple([ + const_config_input16_output16, + const_config_input16, + const_config_input16_output16_positional_weight8, + const_config_input16_positional_weight16]), + base_config=const_config_input16)) + + # define a quantization config to quantize the positional weights into 2 bit (for layers where there is a + # positional weight attribute). 
+ positional_weight_2_attr_config = schema.AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=2, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + const_config_input16_output16_positional_weight2 = const_config_input16_output16.clone_and_edit( + attr_weights_configs_mapping={POS_ATTR: positional_weight_2_attr_config}) + const_configuration_options_inout_2 = ( + schema.QuantizationConfigOptions(quantization_configurations=tuple([ + const_config_input16_output16, + const_config_input16_output16_positional_weight2]), + base_config=const_config_input16_output16)) + + operator_set = [] + + add = schema.OperatorsSet(name=schema.OperatorSetNames.ADD, qc_options=const_configuration_options_inout16) + sub = schema.OperatorsSet(name=schema.OperatorSetNames.SUB, qc_options=const_configuration_options_inout_2) + operator_set.extend([add, sub]) + + generated_tpc = schema.TargetPlatformCapabilities( + default_qco=default_configuration_options, + operator_set=tuple(operator_set)) + + return generated_tpc + def representative_data_gen(shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=1): for _ in range(num_iter): yield [torch.randn(batch_size, *shape)] * num_inputs + def get_float_model(): - class BaseModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) - self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) - self.relu = torch.nn.ReLU() + class BaseModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) + self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.relu(x) + return x + + return BaseModel() + + +def get_float_model_with_constants(): 
+ class BaseModel(torch.nn.Module): + def __init__(self): + super().__init__() + a = torch.rand(8) + b = torch.rand(8) + self.a = to_torch_tensor(a) + self.b = to_torch_tensor(b) + + def forward(self, x): + x = torch.add(x, self.a) + x = torch.sub(self.b, x) + return x - def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - x = self.relu(x) - return x - return BaseModel() + return BaseModel() class TestManualWeightsBitwidthSelectionByLayerType: + def get_float_model(self): + return get_float_model() + + def get_tpc(self, kernel_n_bits, bias_n_bits): + base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits) + tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list) + return tpc + # (LayerType, bit width, attribute, kernel_n_bits, bias_n_bits) test_input_1 = (NodeTypeFilter(Conv2d), 16, PYTORCH_KERNEL, 16, None) test_input_2 = (NodeTypeFilter(Conv2d), [2], [PYTORCH_KERNEL], 2, None) - + test_expected_1 = ([Conv2d], [16]) test_expected_2 = ([Conv2d], [2]) - + @pytest.mark.parametrize(("inputs", "expected"), [ (test_input_1, test_expected_1), (test_input_2, test_expected_2), ]) - def test_manual_weights_bitwidth_selection(self, inputs, expected): - float_model = get_float_model() + float_model = self.get_float_model() + + target_platform_cap = self.get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4]) - target_platform_cap = get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4]) - core_config = CoreConfig() core_config.bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2]) - + quantized_model, _ = mct.ptq.pytorch_post_training_quantization( in_module=float_model, representative_data_gen=representative_data_gen, @@ -157,12 +251,20 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected): attrs = [attrs] for bitwidth, attr in zip(expected_bitwidths, attrs): - + if layer.weights_quantizers.get(attr) is not None: assert layer.weights_quantizers.get(attr).num_bits == bitwidth class 
TestManualWeightsBitwidthSelectionByLayerName: + def get_float_model(self): + return get_float_model() + + def get_tpc(self, kernel_n_bits, bias_n_bits): + base_cfg, mx_cfg_list, default_config = get_op_qco(kernel_n_bits, bias_n_bits) + tpc = generate_tpc_local(default_config, base_cfg, mx_cfg_list) + return tpc + # (LayerName, bit width, attribute, kernel_n_bits, bias_n_bits) test_input_1 = (NodeNameFilter("conv1"), 16, PYTORCH_KERNEL, 16, None) test_input_2 = (NodeNameFilter("conv1"), [2], [PYTORCH_KERNEL], 2, None) @@ -171,22 +273,21 @@ class TestManualWeightsBitwidthSelectionByLayerName: test_expected_1 = (["conv1"], [16]) test_expected_2 = (["conv1"], [2]) test_expected_3 = (["conv1", "conv1"], [4, 16]) - + @pytest.mark.parametrize(("inputs", "expected"), [ (test_input_1, test_expected_1), (test_input_2, test_expected_2), (test_input_3, test_expected_3), ]) - def test_manual_weights_bitwidth_selection(self, inputs, expected): - float_model = get_float_model() + float_model = self.get_float_model() + + target_platform_cap = self.get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4]) - target_platform_cap = get_tpc(kernel_n_bits=inputs[3], bias_n_bits=inputs[4]) - core_config = CoreConfig() core_config.bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2]) - + quantized_model, _ = mct.ptq.pytorch_post_training_quantization( in_module=float_model, representative_data_gen=representative_data_gen, @@ -207,7 +308,54 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected): else: for attr in attrs: if layer.weights_quantizers.get(attr) is not None: - if attr == PYTORCH_KERNEL: - assert layer.weights_quantizers.get(attr).num_bits == kernel_weights_n_bits - elif attr == BIAS: - assert layer.weights_quantizers.get(attr).num_bits == bias_weights_n_bits + if attr == PYTORCH_KERNEL: + assert layer.weights_quantizers.get(attr).num_bits == kernel_weights_n_bits + elif attr == BIAS: + assert layer.weights_quantizers.get(attr).num_bits == 
bias_weights_n_bits + + +class TestManualPositionalAttrWeightsBitwidthSelectionByLayerType(TestManualWeightsBitwidthSelectionByLayerType): + def get_float_model(self): + return get_float_model_with_constants() + + def get_tpc(self, kernel_n_bits, bias_n_bits): + _, _, default_config = get_op_qco(kernel_n_bits, bias_n_bits) + tpc = generate_tpc_pos_attr_local(default_config) + return tpc + + # (LayerType, bit width, attribute) + test_input_1 = (NodeTypeFilter(add), 16, POS_ATTR, 8, 8) + test_input_2 = (NodeTypeFilter(sub), [2], [POS_ATTR], 8, 8) + + test_expected_1 = ([add], [16]) + test_expected_2 = ([sub], [2]) + + @pytest.mark.parametrize(("inputs", "expected"), [ + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + ]) + def test_manual_weights_bitwidth_selection(self, inputs, expected): + super().test_manual_weights_bitwidth_selection(inputs, expected) + +class TestManualPositionalAttrWeightsBitwidthSelectionByLayerName(TestManualWeightsBitwidthSelectionByLayerName): + def get_float_model(self): + return get_float_model_with_constants() + + def get_tpc(self, kernel_n_bits, bias_n_bits): + _, _, default_config = get_op_qco(kernel_n_bits, bias_n_bits) + tpc = generate_tpc_pos_attr_local(default_config) + return tpc + + # (LayerType, bit width, attribute) + test_input_1 = (NodeNameFilter("add"), 8, POS_ATTR, 8, 8) + test_input_2 = (NodeNameFilter("sub"), [2], [POS_ATTR], 8, 8) + + test_expected_1 = (['add'], [16]) + test_expected_2 = (['sub'], [2]) + + @pytest.mark.parametrize(("inputs", "expected"), [ + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + ]) + def test_manual_weights_bitwidth_selection(self, inputs, expected): + super().test_manual_weights_bitwidth_selection(inputs, expected) diff --git a/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py b/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py index 
0e58915ec..11cb23fdd 100755 --- a/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py +++ b/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py @@ -17,12 +17,13 @@ from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter from model_compression_toolkit.core import BitWidthConfig, QuantizationConfig - from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \ set_quantization_configuration_to_graph import torch from torch import nn + +from model_compression_toolkit.core.pytorch.utils import to_torch_tensor from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \ AttachTpcToPytorch @@ -34,9 +35,10 @@ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation -from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, WEIGHTS_N_BITS +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, WEIGHTS_N_BITS, POS_ATTR from model_compression_toolkit.target_platform_capabilities.constants import PYTORCH_KERNEL + class TestManualWeightsBitwidthSelection: def get_op_qco(self): # define a default quantization config for all non-specified weights attributes. 
@@ -118,6 +120,7 @@ def forward(self, x): x = torch.add(x, 2) x = self.relu(x) return x + return BaseModel() def get_test_graph(self, qc): @@ -158,10 +161,10 @@ def get_test_graph(self, qc): @pytest.mark.parametrize( ("inputs", "expected"), [ - (test_input_1, test_expected_1), - (test_input_2, test_expected_2), - (test_input_3, test_expected_3), - ]) + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + (test_input_3, test_expected_3), + ]) def test_manual_weights_bitwidth_selection(self, inputs, expected): for mx_enable in [False, True]: bit_width_config = BitWidthConfig() @@ -197,7 +200,7 @@ def test_manual_weights_bitwidth_selection(self, inputs, expected): ("inputs", "expected"), [ (test_input_4, test_expected_4), (test_input_5, test_expected_5), - ]) + ]) def test_manual_weights_bitwidth_selection_error_add(self, inputs, expected): for mx_enable in [False, True]: bit_width_config = BitWidthConfig() @@ -212,3 +215,123 @@ def test_manual_weights_bitwidth_selection_error_add(self, inputs, expected): ) except Exception as e: assert expected == str(e) + + +class TestManualPositionalAttrWeightsBitwidthSelection(TestManualWeightsBitwidthSelection): + def generate_tpc_local(self, default_config, base_config, mixed_precision_cfg_list): + default_configuration_options = schema.QuantizationConfigOptions( + quantization_configurations=tuple([default_config])) + + const_config_input16 = default_config.clone_and_edit( + supported_input_activation_n_bits=(8, 16)) + const_config_input16_output16 = const_config_input16.clone_and_edit( + activation_n_bits=16, signedness=schema.Signedness.SIGNED) + + # define a quantization config to quantize the positional weights into 16 bit (for layers where there is a + # positional weight attribute). 
+ positional_weight_16_attr_config = schema.AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=16, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + # define a quantization config to quantize the positional weights into 8 bit (for layers where there is a + # positional weight attribute). + positional_weight_8_attr_config = schema.AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=8, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + + const_config_input16_positional_weight16 = const_config_input16.clone_and_edit( + attr_weights_configs_mapping={POS_ATTR: positional_weight_16_attr_config}) + + const_config_input16_output16_positional_weight8 = const_config_input16_output16.clone_and_edit( + attr_weights_configs_mapping={POS_ATTR: positional_weight_8_attr_config}) + const_configuration_options_inout16 = ( + schema.QuantizationConfigOptions(quantization_configurations=tuple([ + const_config_input16_output16, + const_config_input16, + const_config_input16_output16_positional_weight8, + const_config_input16_positional_weight16]), + base_config=const_config_input16)) + + operator_set = [] + + add = schema.OperatorsSet(name=schema.OperatorSetNames.ADD, qc_options=const_configuration_options_inout16) + operator_set.extend([add]) + + generated_tpc = schema.TargetPlatformCapabilities( + default_qco=default_configuration_options, + operator_set=tuple(operator_set)) + + return generated_tpc + + def get_float_model(self): + class BaseModel(torch.nn.Module): + def __init__(self): + super().__init__() + const = torch.rand(8) + self.a = to_torch_tensor(const) + + def forward(self, x): + x = torch.add(x, self.a) + return x + + return BaseModel() + + # test case for set_manual_activation_bit_width + """ + Test Items Policy: + - How to specify the target layer: 
Options(type/name) + - Target attribute information: Options(kernel) + - Bit width variations: Options(2, 4, 16) + """ + test_input_1 = (NodeNameFilter("add"), 16, POS_ATTR) + test_input_2 = (NodeTypeFilter(torch.add), 8, POS_ATTR) + test_input_3 = (NodeNameFilter("add"), 4, POS_ATTR) + test_input_4 = (NodeNameFilter("add"), 2, POS_ATTR) + + test_expected_1 = ({"add": {1: 16}}) + test_expected_2 = ({"add": {1: 8}}) + test_expected_3 = ("Manually selected weights bit-width [[4, 'pos_attr']] is invalid for node " + 'add:add.') + test_expected_4 = ("Manually selected weights bit-width [[2, 'pos_attr']] is invalid for node " + 'add:add.') + + @pytest.mark.parametrize( + ("inputs", "expected"), [ + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + ]) + def test_manual_weights_bitwidth_selection(self, inputs, expected): + bit_width_config = BitWidthConfig() + quantization_config = QuantizationConfig() + graph = self.get_test_graph(quantization_config) + bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2]) + + updated_graph = set_quantization_configuration_to_graph( + graph, quantization_config, bit_width_config + ) + + for node in updated_graph.nodes: + exp_vals = expected.get(node.name) + assert len(node.candidates_quantization_cfg) == 1 + if exp_vals is None: + continue + + cfg_list = node.candidates_quantization_cfg[0].weights_quantization_cfg.pos_attributes_config_mapping + for vkey in cfg_list: + cfg = cfg_list.get(vkey) + if exp_vals.get(vkey) is not None: + assert cfg.weights_n_bits == exp_vals.get(vkey) + + @pytest.mark.parametrize( + ("inputs", "expected"), [ + (test_input_3, test_expected_3), + (test_input_4, test_expected_4), + ]) + def test_manual_weights_bitwidth_selection_error_add(self, inputs, expected): + super().test_manual_weights_bitwidth_selection_error_add(inputs, expected) From b9aa3445689713adf845011f3739bef9b0722796 Mon Sep 17 00:00:00 2001 From: ariell Date: Wed, 23 Apr 2025 14:11:00 +0300 Subject: 
[PATCH 2/3] Add tests for manual bit width config for positional weights --- .../test_weights_manual_selection_bitwidth.py | 13 ++++++------- .../test_set_node_quantization_config.py | 7 ++----- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py b/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py index cd8008c14..7ca3fdaea 100755 --- a/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py +++ b/tests_pytest/pytorch_tests/e2e_tests/test_weights_manual_selection_bitwidth.py @@ -127,13 +127,12 @@ def generate_tpc_pos_attr_local(default_config): const_config_input16_positional_weight16 = const_config_input16.clone_and_edit( attr_weights_configs_mapping={POS_ATTR: positional_weight_16_attr_config}) - const_config_input16_output16_positional_weight8 = const_config_input16_output16.clone_and_edit( + const_config_input16_positional_weight8 = const_config_input16.clone_and_edit( attr_weights_configs_mapping={POS_ATTR: positional_weight_8_attr_config}) const_configuration_options_inout16 = ( schema.QuantizationConfigOptions(quantization_configurations=tuple([ - const_config_input16_output16, const_config_input16, - const_config_input16_output16_positional_weight8, + const_config_input16_positional_weight8, const_config_input16_positional_weight16]), base_config=const_config_input16)) @@ -146,13 +145,13 @@ def generate_tpc_pos_attr_local(default_config): enable_weights_quantization=True, lut_values_bitwidth=None) - const_config_input16_output16_positional_weight2 = const_config_input16_output16.clone_and_edit( + const_config_input16_positional_weight2 = const_config_input16.clone_and_edit( attr_weights_configs_mapping={POS_ATTR: positional_weight_2_attr_config}) const_configuration_options_inout_2 = ( schema.QuantizationConfigOptions(quantization_configurations=tuple([ - const_config_input16_output16, - 
const_config_input16_output16_positional_weight2]), - base_config=const_config_input16_output16)) + const_config_input16, + const_config_input16_positional_weight2]), + base_config=const_config_input16)) operator_set = [] diff --git a/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py b/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py index 11cb23fdd..d9a5c489f 100755 --- a/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py +++ b/tests_pytest/pytorch_tests/integration_tests/core/quantization/test_set_node_quantization_config.py @@ -224,8 +224,6 @@ def generate_tpc_local(self, default_config, base_config, mixed_precision_cfg_li const_config_input16 = default_config.clone_and_edit( supported_input_activation_n_bits=(8, 16)) - const_config_input16_output16 = const_config_input16.clone_and_edit( - activation_n_bits=16, signedness=schema.Signedness.SIGNED) # define a quantization config to quantize the positional weights into 16 bit (for layers where there is a # positional weight attribute). 
@@ -248,13 +246,12 @@ def generate_tpc_local(self, default_config, base_config, mixed_precision_cfg_li const_config_input16_positional_weight16 = const_config_input16.clone_and_edit( attr_weights_configs_mapping={POS_ATTR: positional_weight_16_attr_config}) - const_config_input16_output16_positional_weight8 = const_config_input16_output16.clone_and_edit( + const_config_input16_positional_weight8 = const_config_input16.clone_and_edit( attr_weights_configs_mapping={POS_ATTR: positional_weight_8_attr_config}) const_configuration_options_inout16 = ( schema.QuantizationConfigOptions(quantization_configurations=tuple([ - const_config_input16_output16, const_config_input16, - const_config_input16_output16_positional_weight8, + const_config_input16_positional_weight8, const_config_input16_positional_weight16]), base_config=const_config_input16)) From 0c31279ecd2dadcb3cf4429a9ce71b0fd40c0b6a Mon Sep 17 00:00:00 2001 From: ariell Date: Tue, 29 Apr 2025 09:34:04 +0300 Subject: [PATCH 3/3] Add onnx exporter output names --- .../fakely_quant_onnx_pytorch_exporter.py | 36 ++++-- ...quant_onnx_pytorch_exporter_output_name.py | 120 ++++++++++++++++++ 2 files changed, 148 insertions(+), 8 deletions(-) create mode 100644 tests_pytest/pytorch_tests/unit_tests/exporter/fakely_quant_onnx_pytorch_exporter_output_name.py diff --git a/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py b/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py index af6d46cb2..e6318e83a 100644 --- a/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +++ b/model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py @@ -24,11 +24,11 @@ from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter from mct_quantizers import pytorch_quantizers - if FOUND_ONNX: import onnx from 
mct_quantizers.pytorch.metadata import add_onnx_metadata + class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter): """ Exporter for fakely-quant PyTorch models. @@ -63,7 +63,7 @@ def __init__(self, self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops self._onnx_opset_version = onnx_opset_version - def export(self) -> None: + def export(self, output_names=None) -> None: """ Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model (namely, weights that are in fake-quant format) and fake-quant layers for the activations. @@ -95,6 +95,28 @@ def export(self) -> None: Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}") model_input = to_torch_tensor(next(self.repr_dataset())) + model_output = self.model(*model_input) if isinstance(model_input, (list, tuple)) else self.model( + model_input) + + if output_names is None: + # Determine number of outputs and prepare output_names and dynamic_axes + if isinstance(model_output, (list, tuple)): + output_names = [f"output_{i}" for i in range(len(model_output))] + dynamic_axes = {'input': {0: 'batch_size'}} + dynamic_axes.update({name: {0: 'batch_size'} for name in output_names}) + else: + output_names = ['output'] + dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} + else: + if isinstance(model_output, (list, tuple)): + num_of_outputs = len(model_output) + else: + num_of_outputs = 1 + assert len(output_names) == num_of_outputs, (f"Mismatch between number of requested output names " + f"({output_names}) and model output count " + f"({num_of_outputs}):\n") + dynamic_axes = {'input': {0: 'batch_size'}} + dynamic_axes.update({name: {0: 'batch_size'} for name in output_names}) if hasattr(self.model, 'metadata'): onnx_bytes = BytesIO() @@ -104,9 +126,8 @@ def export(self) -> None: opset_version=self._onnx_opset_version, verbose=False, input_names=['input'], - output_names=['output'], - dynamic_axes={'input': {0: 'batch_size'}, - 'output': {0: 
'batch_size'}}) + output_names=output_names, + dynamic_axes=dynamic_axes) onnx_model = onnx.load_from_string(onnx_bytes.getvalue()) onnx_model = add_onnx_metadata(onnx_model, self.model.metadata) onnx.save_model(onnx_model, self.save_model_path) @@ -117,9 +138,8 @@ def export(self) -> None: opset_version=self._onnx_opset_version, verbose=False, input_names=['input'], - output_names=['output'], - dynamic_axes={'input': {0: 'batch_size'}, - 'output': {0: 'batch_size'}}) + output_names=output_names, + dynamic_axes=dynamic_axes) for layer in self.model.children(): # Set disable for reuse for weight quantizers if quantizer was reused diff --git a/tests_pytest/pytorch_tests/unit_tests/exporter/fakely_quant_onnx_pytorch_exporter_output_name.py b/tests_pytest/pytorch_tests/unit_tests/exporter/fakely_quant_onnx_pytorch_exporter_output_name.py new file mode 100644 index 000000000..1f894dc7a --- /dev/null +++ b/tests_pytest/pytorch_tests/unit_tests/exporter/fakely_quant_onnx_pytorch_exporter_output_name.py @@ -0,0 +1,120 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import onnx +import pytest +import torch +import torch.nn as nn + +from model_compression_toolkit.core.pytorch.utils import set_model +from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \ + FakelyQuantONNXPyTorchExporter +from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION +from model_compression_toolkit.exporter.model_wrapper import is_pytorch_layer_exportable + + +class SingleOutputModel(nn.Module): + def __init__(self): + super(SingleOutputModel, self).__init__() + self.linear = nn.Linear(8, 5) + + def forward(self, x): + return self.linear(x) + + +class MultipleOutputModel(nn.Module): + def __init__(self): + super(MultipleOutputModel, self).__init__() + self.linear = nn.Linear(8, 5) + + def forward(self, x): + return self.linear(x), x, x + 2 + + +class TestONNXExporter: + test_input_1 = None + test_expected_1 = ['output'] + + test_input_2 = ['output_2'] + test_expected_2 = ['output_2'] + + test_input_3 = None + test_expected_3 = ['output_0', 'output_1', 'output_2'] + + test_input_4 = ['out', 'out_11', 'out_22'] + test_expected_4 = ['out', 'out_11', 'out_22'] + + test_input_5 = ['out', 'out_11', 'out_22', 'out_33'] + test_expected_5 = ("Mismatch between number of requested output names (['out', 'out_11', 'out_22', 'out_33']) and " + "model output count (3):\n") + + def representative_data_gen(self, shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=1): + for _ in range(num_iter): + yield [torch.randn(batch_size, *shape)] * num_inputs + + def get_exporter(self, model, save_model_path): + return FakelyQuantONNXPyTorchExporter(model, + is_pytorch_layer_exportable, + save_model_path, + self.representative_data_gen, + onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION) + + def export_model(self, model, save_model_path, output_names, expected_output_names): + exporter = 
self.get_exporter(model, save_model_path) + + exporter.export(output_names) + + assert save_model_path.exists(), "ONNX file was not created" + assert save_model_path.stat().st_size > 0, "ONNX file is empty" + + # Load the ONNX model and check outputs + onnx_model = onnx.load(str(save_model_path)) + outputs = onnx_model.graph.output + + # Check number of outputs + assert len(outputs) == len( + expected_output_names), f"Expected {len(expected_output_names)} output, but found {len(outputs)}" + + found_output_names = [output.name for output in outputs] + assert found_output_names == expected_output_names, ( + f"Expected output name '{expected_output_names}' found {found_output_names}" + ) + + @pytest.mark.parametrize( + ("model", "output_names", "expected_output_names"), [ + (SingleOutputModel(), test_input_1, test_expected_1), + (SingleOutputModel(), test_input_2, test_expected_2), + (MultipleOutputModel(), test_input_3, test_expected_3), + (MultipleOutputModel(), test_input_4, test_expected_4), + ]) + def test_output_model_name(self, tmp_path, model, output_names, expected_output_names): + save_model_path = tmp_path / "model.onnx" + set_model(model) + + self.export_model(model, save_model_path, output_names=output_names, + expected_output_names=expected_output_names) + + @pytest.mark.parametrize( + ("model", "output_names", "expected_output_names"), [ + (MultipleOutputModel(), test_input_5, test_expected_5), + ]) + def test_wrong_number_output_model_name(self, tmp_path, model, output_names, expected_output_names): + save_model_path = tmp_path / "model.onnx" + set_model(model) + + try: + self.export_model(model, save_model_path, output_names=output_names, + expected_output_names=expected_output_names) + except Exception as e: + assert expected_output_names == str(e)