|
25 | 25 | from mct_quantizers import QuantizationMethod |
26 | 26 |
|
27 | 27 |
|
class PreservingNode:
    """Marker layer type used in tests to map nodes to a quantization-preserving op config."""
| 30 | + |
| 31 | + |
class NoActivationQuantNode:
    """Marker layer type used in tests to map nodes to a config with activation quantization disabled."""
| 34 | + |
| 35 | + |
class TestSetNodeQuantizationConfig:
    """Tests for set_quantization_configs_to_node, focusing on how
    quantization-preserving nodes and nodes without activation quantization
    interact with their neighbors' configurations."""

    @staticmethod
    def _get_op_config(activation_n_bits,
                       supported_input_activation_n_bits,
                       enable_activation_quantization,
                       quantization_preserving):
        """Build an OpQuantizationConfig with the given activation settings.

        Args:
            activation_n_bits: bit-width of the op's own activation quantization.
            supported_input_activation_n_bits: bit-width(s) the op accepts on its inputs.
            enable_activation_quantization: whether activation quantization is enabled.
            quantization_preserving: whether the op preserves its input quantization.

        Returns:
            An OpQuantizationConfig using POWER_OF_TWO activation quantization and
            a single weight attribute 'w'.
        """
        aqc = AttributeQuantizationConfig()
        return OpQuantizationConfig(default_weight_attr_config=aqc,
                                    attr_weights_configs_mapping={'w': aqc},
                                    activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
                                    activation_n_bits=activation_n_bits,
                                    supported_input_activation_n_bits=supported_input_activation_n_bits,
                                    enable_activation_quantization=enable_activation_quantization,
                                    quantization_preserving=quantization_preserving,
                                    signedness=Signedness.AUTO)

    def test_activation_preserving_with_2_inputs(self, fw_info_mock):
        """Tests that quantization preserving is dropped for a preserving node with
        two inputs (and for the preserving node that follows it), while a
        single-input chain of preserving nodes keeps its preserving status."""
        n1 = build_node('in1_node')
        n2 = build_node('in2_node')
        # n3 is a preserving node with two inputs; n4 is a preserving node fed by n3.
        n3 = build_node('qp_node', layer_class=PreservingNode)
        n4 = build_node('qp2_node', layer_class=PreservingNode)
        # qp3 -> qp4 is a single-input preserving chain fed by the quantized input n1.
        qp3 = build_node('qp3_node', layer_class=PreservingNode)
        qp4 = build_node('qp4_node', layer_class=PreservingNode)
        graph = Graph('g', input_nodes=[n1, n2], nodes=[n3, qp3], output_nodes=[n4, qp4],
                      edge_list=[Edge(n1, n3, 0, 0), Edge(n2, n3, 0, 0),
                                 Edge(n3, n4, 0, 0),
                                 Edge(n1, qp3, 0, 0), Edge(qp3, qp4, 0, 0)])
        # DummyLayer nodes get a regular quantized config; PreservingNode nodes get a
        # preserving (activation quantization disabled) config.
        q_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7,
                              "enable_activation_quantization": True, "quantization_preserving": False}
        qp_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7,
                               "enable_activation_quantization": False, "quantization_preserving": True}
        _filters = {DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**q_op_config_kwargs)]),
                    PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**qp_op_config_kwargs)])}
        fqc = Mock(filterlayer2qco=_filters, layer2qco=_filters)

        # NOTE(review): deliberately overrides the fixture argument with a mock
        # tailored to this test.
        fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
                            activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
                            get_kernel_op_attributes=lambda x: [None])
        qc = QuantizationConfig()
        # Set configs in topological order so each node can see its inputs' configs.
        for n in graph.get_topo_sorted_nodes():
            set_quantization_configs_to_node(n, graph, qc, fw_info_mock, fqc)
        assert not n3.is_quantization_preserving() and not n3.is_activation_quantization_enabled()
        assert not n4.is_quantization_preserving() and not n4.is_activation_quantization_enabled()
        assert qp3.is_quantization_preserving()
        assert qp4.is_quantization_preserving()

    def test_node_quantization_by_next_nodes(self, fw_info_mock):
        """
        Test that node quantization n_bits is unaffected by preserving next node and not-enabled quantization next node.
        """
        # Chain: first_node (quantized) -> preserving_node -> no_quant_node.
        first_node = build_node('first_node')
        preserving_node = build_node('preserving_node', layer_class=PreservingNode)
        no_quant_node = build_node('no_enabled_quant_node', layer_class=NoActivationQuantNode)
        graph = Graph('g', input_nodes=[first_node], nodes=[preserving_node], output_nodes=[no_quant_node],
                      edge_list=[Edge(first_node, preserving_node, 0, 0),
                                 Edge(preserving_node, no_quant_node, 0, 0)])

        # first_node itself quantizes to 16 bits.
        first_node_config_kwargs = {"activation_n_bits": 16,
                                    "supported_input_activation_n_bits": [8, 16],
                                    "enable_activation_quantization": True,
                                    "quantization_preserving": False}

        preserving_node_config_kwargs = {"activation_n_bits": 8,
                                         "supported_input_activation_n_bits": [8, 16],
                                         "enable_activation_quantization": False,
                                         "quantization_preserving": True}

        # The final consumer only supports 8-bit inputs.
        no_quant_node_config_kwargs = {"activation_n_bits": 8,
                                       "supported_input_activation_n_bits": [8],
                                       "enable_activation_quantization": False,
                                       "quantization_preserving": False}
        _filters = {
            DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**first_node_config_kwargs)]),
            PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**preserving_node_config_kwargs)]),
            NoActivationQuantNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**no_quant_node_config_kwargs)]),
        }

        fqc = Mock(filterlayer2qco=_filters, layer2qco=_filters)
        # NOTE(review): deliberately overrides the fixture argument with a mock
        # tailored to this test.
        fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
                            activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
                            get_kernel_op_attributes=lambda x: [None])
        quantization_config = QuantizationConfig()
        set_quantization_configs_to_node(first_node, graph, quantization_config, fw_info_mock, fqc)
        set_quantization_configs_to_node(preserving_node, graph, quantization_config, fw_info_mock, fqc)
        set_quantization_configs_to_node(no_quant_node, graph, quantization_config, fw_info_mock, fqc)

        assert not first_node.is_quantization_preserving() and first_node.is_activation_quantization_enabled()
        assert preserving_node.is_quantization_preserving() and not preserving_node.is_activation_quantization_enabled()
        assert not no_quant_node.is_quantization_preserving() and not no_quant_node.is_activation_quantization_enabled()

        # assert that first_node n_bits is 16, and isn't affected by its next nodes which supports 8 n_bits only
        assert first_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits == 16
0 commit comments