|
13 | 13 | # limitations under the License. |
14 | 14 | # ============================================================================== |
15 | 15 | from types import MethodType |
16 | | -from typing import Iterable, Union |
17 | 16 | from unittest.mock import Mock |
18 | 17 |
|
19 | 18 | import numpy as np |
20 | 19 | import pytest |
21 | | -from mct_quantizers import QuantizationMethod |
22 | 20 |
|
23 | 21 | from model_compression_toolkit.constants import FLOAT_BITWIDTH |
24 | | -from model_compression_toolkit.core import QuantizationConfig, ResourceUtilization |
25 | | -from model_compression_toolkit.core.common import Graph, BaseNode |
| 22 | +from model_compression_toolkit.core import ResourceUtilization |
| 23 | +from model_compression_toolkit.core.common import Graph |
26 | 24 | from model_compression_toolkit.core.common.graph.edge import Edge |
27 | 25 | from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut |
28 | 26 | from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut |
|
34 | 32 | RUTarget |
35 | 33 | from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \ |
36 | 34 | Utilization, ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode |
37 | | -from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \ |
38 | | - CandidateNodeQuantizationConfig |
39 | | -from model_compression_toolkit.core.common.quantization.node_quantization_config import \ |
40 | | - NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig |
41 | | -from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \ |
42 | | - AttributeQuantizationConfig, Signedness |
43 | | - |
44 | | - |
def full_attr_name(canonical_name: Union[str, list, tuple, set]):
    """ Convert canonical attr (such as 'kernel') into a full name originated from the layer (e.g. 'conv2d_1/kernel:0')
    We just need the names to differ from canonical to make sure we call the correct apis. We use the same
    template for simplicity, so we don't have to explicitly synchronize names between node and weight configs.

    Args:
        canonical_name: a single attribute name, or a list / tuple / set of attributes.
          Non-string entries of a collection are returned unchanged.

    Returns:
        The converted name for a string input, otherwise a collection of the same type
        as the input holding the converted entries.

    Raises:
        AssertionError: if canonical_name is neither a string nor a list / tuple / set.
    """
    def convert(name):
        # template: '<first char>/<name>/<last char>'; anything that is not a string passes through as is
        return f'{name[0]}/{name}/{name[-1]}' if isinstance(name, str) else name

    if isinstance(canonical_name, str):
        return convert(canonical_name)
    assert isinstance(canonical_name, (list, tuple, set)), \
        f'Expected str, list, tuple or set, got {type(canonical_name).__name__}'
    # rebuild the same container type as was passed in
    return canonical_name.__class__([convert(name) for name in canonical_name])
54 | | - |
55 | | - |
def build_qc(a_nbits=8, a_enable=True, w_attr=None, pos_attr=(32, False, ())):
    """ Create a candidate quantization config for tests.

    Args:
        a_nbits: activation bit-width.
        a_enable: whether activation quantization is enabled.
        w_attr: {canonical attr name: (nbits, q_enabled)}.
        pos_attr: (nbits, q_enabled, indices) for positional attrs.

    Returns:
        CandidateNodeQuantizationConfig built from the requested settings.
    """
    w_attr = w_attr or {}

    # per-attribute weight configs, keyed by canonical names (as 'kernel')
    attr_cfg_by_canonical_name = {}
    for canonical_name, attr_spec in w_attr.items():
        attr_cfg_by_canonical_name[canonical_name] = AttributeQuantizationConfig(
            weights_n_bits=attr_spec[0], enable_weights_quantization=attr_spec[1])

    quant_cfg = QuantizationConfig()

    # positional attrs are set via default weight config (so all pos attrs have the same q config)
    default_attr_cfg = AttributeQuantizationConfig(weights_n_bits=pos_attr[0],
                                                   enable_weights_quantization=pos_attr[1])
    op_cfg = OpQuantizationConfig(
        attr_weights_configs_mapping=attr_cfg_by_canonical_name,
        activation_n_bits=a_nbits,
        enable_activation_quantization=a_enable,
        default_weight_attr_config=default_attr_cfg,
        activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
        quantization_preserving=False,
        supported_input_activation_n_bits=[2, 4, 8],
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=None,
        signedness=Signedness.AUTO
    )

    activation_cfg = NodeActivationQuantizationConfig(qc=quant_cfg, op_cfg=op_cfg,
                                                      activation_quantization_fn=None,
                                                      activation_quantization_params_fn=None)

    # the node-level weights config takes full names from the layers, followed by positional indices
    full_names = [full_attr_name(canonical_name) for canonical_name in w_attr.keys()]
    weights_cfg = NodeWeightsQuantizationConfig(qc=quant_cfg, op_cfg=op_cfg,
                                                weights_channels_axis=None,
                                                node_attrs_list=full_names + list(pos_attr[2]))

    candidate = CandidateNodeQuantizationConfig(activation_quantization_cfg=activation_cfg,
                                                weights_quantization_cfg=weights_cfg)

    # the configs are generated via constructors to follow the real code as closely as reasonably possible,
    # so verify that the result actually holds the requested configuration
    assert candidate.activation_quantization_cfg.activation_n_bits == a_nbits
    assert candidate.activation_quantization_cfg.enable_activation_quantization is a_enable
    for canonical_name, attr_spec in w_attr.items():
        # get_attr_config accepts canonical attr names
        assert candidate.weights_quantization_cfg.get_attr_config(canonical_name).weights_n_bits == attr_spec[0]
        assert candidate.weights_quantization_cfg.get_attr_config(canonical_name).enable_weights_quantization == attr_spec[1]
    for pos in pos_attr[2]:
        assert candidate.weights_quantization_cfg.get_attr_config(pos).weights_n_bits == pos_attr[0]
        assert candidate.weights_quantization_cfg.get_attr_config(pos).enable_weights_quantization == pos_attr[1]

    return candidate
106 | | - |
107 | | - |
class DummyLayer:
    """ Placeholder layer class. Only needed for repr(node) to work. """
111 | | - |
112 | | - |
def build_node(name='node', canonical_weights: dict=None, qcs=None, input_shape=(4, 5, 6), output_shape=(4, 5, 6),
               layer_class=DummyLayer, reuse=False):
    """ Create a BaseNode for tests.

    String keys of canonical_weights are converted into full names via full_attr_name,
    integer keys are kept as is.
    candidates_quantization_cfg is set only if a non-empty qcs is passed.
    """
    full_name_weights = {}
    for attr, weight in (canonical_weights or {}).items():
        # integer keys are left untouched; only named attrs are converted
        attr_key = attr if isinstance(attr, int) else full_attr_name(attr)
        full_name_weights[attr_key] = weight

    node = BaseNode(name=name,
                    framework_attr={},
                    input_shape=input_shape,
                    output_shape=output_shape,
                    weights=full_name_weights,
                    layer_class=layer_class,
                    reuse=reuse)
    if not qcs:
        return node
    node.candidates_quantization_cfg = qcs
    return node
130 | | - |
| 35 | +from tests_pytest.test_util.graph_builder_utils import build_node, build_qc, full_attr_name |
131 | 36 |
|
132 | 37 | BM = BitwidthMode |
133 | 38 | TIC = TargetInclusionCriterion |
|
0 commit comments