diff --git a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
index 810023322..d09b5436a 100644
--- a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
+++ b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
@@ -119,11 +119,16 @@ def filter_node_qco_by_graph(node: BaseNode,
                 _next_nodes.extend(graph.get_next_nodes(n))
             next_nodes.append(n)
 
-        if len(next_nodes):
-            next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
-            next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
+        if len(next_nodes) == 0:
+            return _base_config, _node_qc_options
+        next_nodes_qc_options = [_node.get_qco(fqc) for _node in next_nodes]
+        all_next_nodes_supported_input_bitwidth = [max_input_activation_n_bits(op_cfg)
                                                    for qc_opts in next_nodes_qc_options
-                                                   for op_cfg in qc_opts.quantization_configurations])
+                                                   for op_cfg in qc_opts.quantization_configurations
+                                                   if op_cfg.enable_activation_quantization or op_cfg.quantization_preserving
+                                                   ]
+        if len(all_next_nodes_supported_input_bitwidth):
+            next_nodes_supported_input_bitwidth = min(all_next_nodes_supported_input_bitwidth)
 
             # Filter node's QC options that match next nodes input bit-width.
             _node_qc_options = [_option for _option in _node_qc_options
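The hunk above makes two changes to filter_node_qco_by_graph: it returns early when the node has no relevant successors, and it builds the successor input bit-width constraint only from op configs that actually consume a quantized input (activation quantization enabled, or quantization-preserving), leaving the candidates unfiltered when no such config exists. A minimal standalone sketch of the new behavior; SimpleOpCfg and max_input_nbits are hypothetical stand-ins for OpQuantizationConfig and max_input_activation_n_bits, assuming the latter returns the maximum supported input bit-width:

    from dataclasses import dataclass

    @dataclass
    class SimpleOpCfg:  # stand-in for OpQuantizationConfig
        activation_n_bits: int
        supported_input_activation_n_bits: tuple = (8,)
        enable_activation_quantization: bool = True
        quantization_preserving: bool = False

    def max_input_nbits(cfg):
        # Assumed semantics of max_input_activation_n_bits.
        return max(cfg.supported_input_activation_n_bits)

    def filter_by_next_nodes(options, next_cfgs):
        # Only successors that actually consume a quantized input constrain the node.
        relevant = [max_input_nbits(c) for c in next_cfgs
                    if c.enable_activation_quantization or c.quantization_preserving]
        if not relevant:
            return options  # no constraint from successors
        limit = min(relevant)
        return [opt for opt in options if opt.activation_n_bits <= limit]

    # A 16-bit candidate now survives when the only 8-bit-limited successor is unquantized:
    opts = [SimpleOpCfg(16), SimpleOpCfg(8)]
    succ = [SimpleOpCfg(8, (8, 16), enable_activation_quantization=False, quantization_preserving=True),
            SimpleOpCfg(8, (8,), enable_activation_quantization=False, quantization_preserving=False)]
    assert [o.activation_n_bits for o in filter_by_next_nodes(opts, succ)] == [16, 8]

With the old code, the unquantized successor's 8-bit limit would have been included in the min and filtered out the 16-bit candidate.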
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.") candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT diff --git a/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py b/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py index 56d4e3891..ac2eb9fd2 100644 --- a/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +++ b/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py @@ -343,6 +343,13 @@ def shift_negative_function(graph: Graph, graph.set_out_stats_collector_to_node(add_node, add_node_stats_collector) graph.shift_stats_collector(add_node, np.array(shift_value)) + set_quantization_configs_to_node(fw_info=fw_info, + node=add_node, + graph=graph, + quant_config=core_config.quantization_config, + fqc=graph.fqc, + mixed_precision_enable=core_config.is_mixed_precision_enabled) + if padding is not None: pad_node = create_pad_node(op2d_node.name, add_node.name, @@ -373,13 +380,6 @@ def shift_negative_function(graph: Graph, op2d_node.input_shape = pad_node.output_shape - set_quantization_configs_to_node(fw_info=fw_info, - node=add_node, - graph=graph, - quant_config=core_config.quantization_config, - fqc=graph.fqc, - mixed_precision_enable=core_config.is_mixed_precision_enabled) - original_non_linear_activation_nbits = non_linear_node_cfg_candidate.activation_n_bits # The non-linear node's output should be float, so we approximate it by using 16bits quantization. for candidate_qc in non_linear_node.candidates_quantization_cfg: diff --git a/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py b/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py index 3bbea4d73..ff676016c 100644 --- a/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py +++ b/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py @@ -25,37 +25,105 @@ from mct_quantizers import QuantizationMethod +class PreservingNode: + pass + + +class NoActivationQuantNode: + pass + + class TestSetNodeQuantizationConfig: @staticmethod - def _get_op_config(): + def _get_op_config(activation_n_bits, + supported_input_activation_n_bits, + enable_activation_quantization, + quantization_preserving): aqc = AttributeQuantizationConfig() return OpQuantizationConfig(default_weight_attr_config=aqc, attr_weights_configs_mapping={'w': aqc}, activation_quantization_method=QuantizationMethod.POWER_OF_TWO, - activation_n_bits=7, - supported_input_activation_n_bits=7, - enable_activation_quantization=False, - quantization_preserving=True, + activation_n_bits=activation_n_bits, + supported_input_activation_n_bits=supported_input_activation_n_bits, + enable_activation_quantization=enable_activation_quantization, + quantization_preserving=quantization_preserving, signedness=Signedness.AUTO) def test_activation_preserving_with_2_inputs(self, fw_info_mock): """ Tests that . 
""" n1 = build_node('in1_node') n2 = build_node('in2_node') - n3 = build_node('qp_node') - n4 = build_node('qp2_node') - graph = Graph('g', input_nodes=[n1, n2], nodes=[n3], output_nodes=[n4], + n3 = build_node('qp_node', layer_class=PreservingNode) + n4 = build_node('qp2_node', layer_class=PreservingNode) + qp3 = build_node('qp3_node', layer_class=PreservingNode) + qp4 = build_node('qp4_node', layer_class=PreservingNode) + graph = Graph('g', input_nodes=[n1, n2], nodes=[n3, qp3], output_nodes=[n4, qp4], edge_list=[Edge(n1, n3, 0, 0), Edge(n2, n3, 0, 0), - Edge(n3, n4, 0, 0)]) + Edge(n3, n4, 0, 0), + Edge(n1, qp3, 0, 0), Edge(qp3, qp4, 0, 0)]) + q_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7, + "enable_activation_quantization": True, "quantization_preserving": False} + qp_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7, + "enable_activation_quantization": False, "quantization_preserving": True} + _filters = {DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**q_op_config_kwargs)]), + PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**qp_op_config_kwargs)])} + fqc = Mock(filterlayer2qco=_filters, layer2qco=_filters) - fqc = Mock(filterlayer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])}, - layer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])}) fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0}, activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0}, get_kernel_op_attributes=lambda x: [None]) - set_quantization_configs_to_node(n3, graph, QuantizationConfig(), fw_info_mock, fqc) - set_quantization_configs_to_node(n4, graph, QuantizationConfig(), fw_info_mock, fqc) + qc = QuantizationConfig() + for n in graph.get_topo_sorted_nodes(): + set_quantization_configs_to_node(n, graph, qc, fw_info_mock, fqc) assert not n3.is_quantization_preserving() and not n3.is_activation_quantization_enabled() assert not n4.is_quantization_preserving() and not n4.is_activation_quantization_enabled() + assert qp3.is_quantization_preserving() + assert qp4.is_quantization_preserving() + + def test_node_quantization_by_next_nodes(self, fw_info_mock): + """ + Test that node quantization n_bits is unaffected by preserving next node and not-enabled quantization next node. 
diff --git a/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py b/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py
index 3bbea4d73..ff676016c 100644
--- a/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py
+++ b/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py
@@ -25,37 +25,105 @@ from mct_quantizers import QuantizationMethod
 
 
+class PreservingNode:
+    pass
+
+
+class NoActivationQuantNode:
+    pass
+
+
 class TestSetNodeQuantizationConfig:
 
     @staticmethod
-    def _get_op_config():
+    def _get_op_config(activation_n_bits,
+                       supported_input_activation_n_bits,
+                       enable_activation_quantization,
+                       quantization_preserving):
         aqc = AttributeQuantizationConfig()
         return OpQuantizationConfig(default_weight_attr_config=aqc,
                                     attr_weights_configs_mapping={'w': aqc},
                                     activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
-                                    activation_n_bits=7,
-                                    supported_input_activation_n_bits=7,
-                                    enable_activation_quantization=False,
-                                    quantization_preserving=True,
+                                    activation_n_bits=activation_n_bits,
+                                    supported_input_activation_n_bits=supported_input_activation_n_bits,
+                                    enable_activation_quantization=enable_activation_quantization,
+                                    quantization_preserving=quantization_preserving,
                                     signedness=Signedness.AUTO)
 
     def test_activation_preserving_with_2_inputs(self, fw_info_mock):
         """
         Tests that quantization preserving is disabled for a node with two quantized input activations,
         and kept for a preserving node with a single quantized input.
         """
         n1 = build_node('in1_node')
         n2 = build_node('in2_node')
-        n3 = build_node('qp_node')
-        n4 = build_node('qp2_node')
-        graph = Graph('g', input_nodes=[n1, n2], nodes=[n3], output_nodes=[n4],
+        n3 = build_node('qp_node', layer_class=PreservingNode)
+        n4 = build_node('qp2_node', layer_class=PreservingNode)
+        qp3 = build_node('qp3_node', layer_class=PreservingNode)
+        qp4 = build_node('qp4_node', layer_class=PreservingNode)
+        graph = Graph('g', input_nodes=[n1, n2], nodes=[n3, qp3], output_nodes=[n4, qp4],
                       edge_list=[Edge(n1, n3, 0, 0), Edge(n2, n3, 0, 0),
-                                 Edge(n3, n4, 0, 0)])
+                                 Edge(n3, n4, 0, 0),
+                                 Edge(n1, qp3, 0, 0), Edge(qp3, qp4, 0, 0)])
+        q_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7,
+                              "enable_activation_quantization": True, "quantization_preserving": False}
+        qp_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7,
+                               "enable_activation_quantization": False, "quantization_preserving": True}
+        _filters = {DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**q_op_config_kwargs)]),
+                    PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**qp_op_config_kwargs)])}
+        fqc = Mock(filterlayer2qco=_filters, layer2qco=_filters)
 
-        fqc = Mock(filterlayer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])},
-                   layer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])})
         fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
                             activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
                             get_kernel_op_attributes=lambda x: [None])
-        set_quantization_configs_to_node(n3, graph, QuantizationConfig(), fw_info_mock, fqc)
-        set_quantization_configs_to_node(n4, graph, QuantizationConfig(), fw_info_mock, fqc)
+        qc = QuantizationConfig()
+        for n in graph.get_topo_sorted_nodes():
+            set_quantization_configs_to_node(n, graph, qc, fw_info_mock, fqc)
 
         assert not n3.is_quantization_preserving() and not n3.is_activation_quantization_enabled()
         assert not n4.is_quantization_preserving() and not n4.is_activation_quantization_enabled()
+        assert qp3.is_quantization_preserving()
+        assert qp4.is_quantization_preserving()
+
+    def test_node_quantization_by_next_nodes(self, fw_info_mock):
+        """
+        Test that a node's activation n_bits is unaffected by a quantization-preserving next node
+        and by a next node whose activation quantization is not enabled.
+        """
+        first_node = build_node('first_node')
+        preserving_node = build_node('preserving_node', layer_class=PreservingNode)
+        no_quant_node = build_node('no_enabled_quant_node', layer_class=NoActivationQuantNode)
+        graph = Graph('g', input_nodes=[first_node], nodes=[preserving_node], output_nodes=[no_quant_node],
+                      edge_list=[Edge(first_node, preserving_node, 0, 0),
+                                 Edge(preserving_node, no_quant_node, 0, 0)])
+
+        first_node_config_kwargs = {"activation_n_bits": 16,
+                                    "supported_input_activation_n_bits": [8, 16],
+                                    "enable_activation_quantization": True,
+                                    "quantization_preserving": False}
+
+        preserving_node_config_kwargs = {"activation_n_bits": 8,
+                                         "supported_input_activation_n_bits": [8, 16],
+                                         "enable_activation_quantization": False,
+                                         "quantization_preserving": True}
+
+        no_quant_node_config_kwargs = {"activation_n_bits": 8,
+                                       "supported_input_activation_n_bits": [8],
+                                       "enable_activation_quantization": False,
+                                       "quantization_preserving": False}
+
+        _filters = {
+            DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**first_node_config_kwargs)]),
+            PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**preserving_node_config_kwargs)]),
+            NoActivationQuantNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**no_quant_node_config_kwargs)]),
+        }
+
+        fqc = Mock(filterlayer2qco=_filters, layer2qco=_filters)
+        fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
+                            activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
+                            get_kernel_op_attributes=lambda x: [None])
+        quantization_config = QuantizationConfig()
+        set_quantization_configs_to_node(first_node, graph, quantization_config, fw_info_mock, fqc)
+        set_quantization_configs_to_node(preserving_node, graph, quantization_config, fw_info_mock, fqc)
+        set_quantization_configs_to_node(no_quant_node, graph, quantization_config, fw_info_mock, fqc)
+
+        assert not first_node.is_quantization_preserving() and first_node.is_activation_quantization_enabled()
+        assert preserving_node.is_quantization_preserving() and not preserving_node.is_activation_quantization_enabled()
+        assert not no_quant_node.is_quantization_preserving() and not no_quant_node.is_activation_quantization_enabled()
+        # assert that first_node n_bits is 16 and isn't affected by its next nodes, which support only 8-bit inputs
+        assert first_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits == 16
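As a worked check of the new test's arithmetic, under the same assumption as above that max_input_activation_n_bits returns the maximum of supported_input_activation_n_bits: the preserving successor advertises [8, 16] and is counted, while the no-quant successor's [8] is excluded from the min entirely, so first_node's 16-bit candidate survives:

    # Mirrors the two successor configs in test_node_quantization_by_next_nodes.
    preserving = dict(supported=(8, 16), enabled=False, preserving=True)   # counted
    no_quant   = dict(supported=(8,),    enabled=False, preserving=False)  # excluded

    relevant = [max(c["supported"]) for c in (preserving, no_quant)
                if c["enabled"] or c["preserving"]]
    assert relevant == [16]     # only the preserving successor constrains the input
    assert min(relevant) == 16  # so the 16-bit activation candidate is kept

Before the filter_node_qco_by_graph change, the no-quant successor would have contributed an 8-bit limit to the min, filtering out the 16-bit option and failing the final assertion.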