|
| 1 | +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +import pytest |
| 16 | + |
| 17 | +import model_compression_toolkit as mct |
| 18 | +from model_compression_toolkit.constants import PYTORCH |
| 19 | +from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter |
| 20 | +from model_compression_toolkit.core.common.quantization.bit_width_config import ManualBitWidthSelection, ManualWeightsBitWidthSelection |
| 21 | +from model_compression_toolkit.core import BitWidthConfig, CoreConfig |
| 22 | + |
| 23 | +from model_compression_toolkit.core.common import Graph |
| 24 | +from model_compression_toolkit.core.common.graph.edge import Edge |
| 25 | +from tests_pytest._test_util.graph_builder_utils import build_node |
| 26 | + |
| 27 | +from model_compression_toolkit.core.common.quantization.set_node_quantization_config import \ |
| 28 | + set_quantization_configuration_to_graph |
| 29 | +from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import \ |
| 30 | + FrameworkQuantizationCapabilities, OperationsSetToLayers |
| 31 | + |
| 32 | +from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO |
| 33 | +from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO |
| 34 | + |
| 35 | +import torch |
| 36 | +from torch import nn |
| 37 | +from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \ |
| 38 | + AttachTpcToPytorch |
| 39 | + |
| 40 | +from model_compression_toolkit.core import QuantizationConfig |
| 41 | +from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner |
| 42 | +from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO |
| 43 | +from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation |
| 44 | + |
| 45 | +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs, generate_tpc |
| 46 | + |
| 47 | + |
| 48 | +#TEST_KERNEL = 'kernel' |
| 49 | +#TEST_BIAS = 'bias' |
| 50 | + |
| 51 | +### dummy layer classes |
| 52 | +class Conv2D: |
| 53 | + pass |
| 54 | +class InputLayer: |
| 55 | + pass |
| 56 | +class Add: |
| 57 | + pass |
| 58 | +class BatchNormalization: |
| 59 | + pass |
| 60 | +class ReLU: |
| 61 | + pass |
| 62 | +class Flatten: |
| 63 | + pass |
| 64 | +class Dense: |
| 65 | + pass |
| 66 | + |
| 67 | +#from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs |
| 68 | + |
| 69 | +from tests.pytorch_tests.tpc_pytorch import get_mp_activation_pytorch_tpc_dict |
| 70 | +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR |
| 71 | + |
| 72 | +from tests.common_tests.helpers.generate_test_tpc import generate_tpc_with_activation_mp |
| 73 | + |
| 74 | +import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema |
| 75 | + |
| 76 | +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS |
| 77 | +def get_tpc(kernel_n, bias_n): |
| 78 | + #kernel_weights_n_bits = 8 ### [DEBUG0404] 8 ni suruto Error. 16 dato ugoku. |
| 79 | + #bias_weights_n_bits = 32 |
| 80 | + #activation_n_bits = 8 |
| 81 | + |
| 82 | + base_cfg, _, default_config = get_op_quantization_configs() |
| 83 | + """ |
| 84 | + base_config = base_cfg.clone_and_edit(attr_weights_configs_mapping= |
| 85 | + { |
| 86 | + KERNEL_ATTR: base_cfg.attr_weights_configs_mapping[KERNEL_ATTR] |
| 87 | + .clone_and_edit(weights_n_bits=kernel_weights_n_bits), |
| 88 | + BIAS_ATTR: base_cfg.attr_weights_configs_mapping[BIAS_ATTR] |
| 89 | + .clone_and_edit(weights_n_bits=bias_weights_n_bits, enable_weights_quantization=True), |
| 90 | + }, |
| 91 | + activation_n_bits=activation_n_bits) |
| 92 | + """ |
| 93 | + weights_04_bits = base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 4}}) |
| 94 | + weights_02_bits = base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 2}}) |
| 95 | + weights_16_bits = base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: 16}}) |
| 96 | + |
| 97 | + mx_cfg_list = [base_cfg, weights_04_bits, weights_02_bits, weights_16_bits] |
| 98 | + tpc = generate_tpc(default_config, base_cfg, mx_cfg_list, 'imx500_tpc_kai') |
| 99 | + |
| 100 | + return tpc |
| 101 | + |
| 102 | + |
| 103 | +# AttributeQuantizationConfig(weights_quantization_method=<QuantizationMethod.SYMMETRIC: 2>, weights_n_bits=16, weights_per_channel_threshold=True, enable_weights_quantization=True, lut_values_bitwidth=None) |
| 104 | +from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig |
| 105 | +from tests.common_tests.helpers.generate_test_tpc import generate_test_tpc |
| 106 | + |
| 107 | +### test model |
| 108 | +def get_test_graph(kernel_n, bias_n): |
| 109 | + n1 = build_node('input', layer_class=InputLayer) |
| 110 | + conv1 = build_node('conv1', layer_class=Conv2D, |
| 111 | + canonical_weights={ |
| 112 | + KERNEL_ATTR: AttributeQuantizationConfig(weights_n_bits=8), |
| 113 | + BIAS_ATTR: AttributeQuantizationConfig(weights_n_bits=32)} |
| 114 | + ) |
| 115 | + add1 = build_node('add1', layer_class=Add) |
| 116 | + conv2 = build_node('conv2', layer_class=Conv2D, |
| 117 | + canonical_weights={ |
| 118 | + KERNEL_ATTR: AttributeQuantizationConfig(weights_n_bits=8), |
| 119 | + BIAS_ATTR: AttributeQuantizationConfig(weights_n_bits=32)} |
| 120 | + ) |
| 121 | + bn1 = build_node('bn1', layer_class=BatchNormalization) |
| 122 | + relu = build_node('relu1', layer_class=ReLU, |
| 123 | + canonical_weights={ |
| 124 | + KERNEL_ATTR: AttributeQuantizationConfig(weights_n_bits=8), |
| 125 | + BIAS_ATTR: AttributeQuantizationConfig(weights_n_bits=32)} |
| 126 | + ) |
| 127 | + add2 = build_node('add2', layer_class=Add) |
| 128 | + flatten = build_node('flatten', layer_class=Flatten) |
| 129 | + fc = build_node('fc', layer_class=Dense) |
| 130 | + |
| 131 | + graph = Graph('xyz', input_nodes=[n1], |
| 132 | + nodes=[conv1,add1, conv2, bn1, relu, add2, flatten], |
| 133 | + output_nodes=[fc], |
| 134 | + edge_list=[Edge(n1, conv1, 0, 0), |
| 135 | + Edge(conv1, add1, 0, 0), |
| 136 | + Edge(add1, conv2, 0, 0), |
| 137 | + Edge(conv2, bn1, 0, 0), |
| 138 | + Edge(bn1, relu, 0, 0), |
| 139 | + Edge(relu, add2, 0, 0), |
| 140 | + Edge(add1, add2, 0, 0), |
| 141 | + Edge(add2, flatten, 0, 0), |
| 142 | + Edge(flatten, fc, 0, 0), |
| 143 | + ] |
| 144 | + ) |
| 145 | + |
| 146 | + tpc = get_tpc(kernel_n, bias_n) |
| 147 | + #tpc = mct.get_target_platform_capabilities('pytorch', 'default') |
| 148 | + #print('tpc', tpc) |
| 149 | + #print(type(tpc)) |
| 150 | + #for val in tpc: |
| 151 | + # print(val) |
| 152 | + #print('a'+1) |
| 153 | + fqc = FrameworkQuantizationCapabilities(tpc) |
| 154 | + graph.set_fqc(fqc) |
| 155 | + |
| 156 | + fw_info = DEFAULT_PYTORCH_INFO |
| 157 | + graph.set_fw_info(fw_info) |
| 158 | + return graph |
| 159 | + |
| 160 | +class TestManualWeightsBitwidthSelection: |
| 161 | + # test case for set_manual_activation_bit_width |
| 162 | + test_input_1 = (NodeTypeFilter(Conv2D), 8, KERNEL_ATTR) |
| 163 | + test_input_2 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16], [KERNEL_ATTR]) |
| 164 | + test_input_3 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [4, 8], [KERNEL_ATTR, BIAS_ATTR]) |
| 165 | + |
| 166 | + test_expected_1 = (NodeTypeFilter, ReLU, 16) |
| 167 | + test_expected_2 = ([NodeTypeFilter, ReLU, 2], [NodeNameFilter, "conv1", 2]) |
| 168 | + test_expected_3 = ([NodeTypeFilter, ReLU, 4], [NodeNameFilter, "conv1", 8]) |
| 169 | + |
| 170 | + @pytest.mark.parametrize(("inputs", "expected"), [ |
| 171 | + (test_input_1, test_expected_1), |
| 172 | + #(test_input_2, test_expected_2), |
| 173 | + #(test_input_3, test_expected_3), |
| 174 | + ]) |
| 175 | + def test_manual_weights_bitwidth_selection(self, inputs, expected): |
| 176 | + print('# test_manual_weights_bitwidth_selection start.') |
| 177 | + |
| 178 | + print('inputs', inputs) |
| 179 | + print('expected', expected) |
| 180 | + |
| 181 | + kernel_n = 8 |
| 182 | + bias_n = 32 |
| 183 | + if KERNEL_ATTR in inputs[2]: |
| 184 | + indices = [index for index, value in enumerate(inputs[2]) if value == KERNEL_ATTR] |
| 185 | + kernel_n = inputs[1] if type(inputs[2]) != list else inputs[1][indices[0]] |
| 186 | + if BIAS_ATTR in inputs[2]: |
| 187 | + indices = [index for index, value in enumerate(inputs[2]) if value == BIAS_ATTR] |
| 188 | + bias_n = inputs[1] if type(inputs[2]) != list else inputs[1][indices[0]] |
| 189 | + print('kernel_n, bias_n', kernel_n, bias_n) |
| 190 | + graph = get_test_graph(kernel_n, bias_n) |
| 191 | + #graph = get_test_graph() |
| 192 | + print('graph', graph) |
| 193 | + core_config = CoreConfig() |
| 194 | + |
| 195 | + core_config.bit_width_config.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2]) |
| 196 | + |
| 197 | + updated_graph = set_quantization_configuration_to_graph( |
| 198 | + graph, core_config.quantization_config, core_config.bit_width_config, |
| 199 | + False, False |
| 200 | + ) |
| 201 | + print('------graph---------------------') |
| 202 | + print('0', graph) |
| 203 | + print('1', graph.nodes) |
| 204 | + print('2', graph.nodes.keys()) |
| 205 | + """ |
| 206 | + for n in graph.nodes: |
| 207 | + print('n', n) |
| 208 | + a = graph.get_weights_configurable_nodes(DEFAULT_PYTORCH_INFO, True) |
| 209 | + b = graph.get_activation_configurable_nodes() |
| 210 | + print('a', a) |
| 211 | + print('b', b) |
| 212 | +
|
| 213 | + ### len(node.candidates_quantization_cfg) de Error. |
| 214 | + for node in updated_graph.nodes: |
| 215 | + print("z", node) #, node.candidates_quantization_cfg |
| 216 | + for ii in range(len(node.candidates_quantization_cfg)): |
| 217 | + print('z4', ii, node.candidates_quantization_cfg[ii].weights_quantization_cfg.attributes_config_mapping) |
| 218 | + #print('z4 0', ii, type(node.candidates_quantization_cfg[ii].weights_quantization_cfg.attributes_config_mapping)) |
| 219 | + for vkey in node.candidates_quantization_cfg[ii].weights_quantization_cfg.attributes_config_mapping: |
| 220 | + #print('z5', vkey, node.candidates_quantization_cfg[ii].weights_quantization_cfg.attributes_config_mapping[vkey]) |
| 221 | + cfg = node.candidates_quantization_cfg[ii].weights_quantization_cfg.attributes_config_mapping[vkey] |
| 222 | + print('z5 cfg.weights_n_bits', cfg.weights_n_bits) |
| 223 | + """ |
| 224 | + |
| 225 | + print('------updated graph---------------------') |
| 226 | + print(updated_graph) |
| 227 | + |
| 228 | + for node in updated_graph.nodes: |
| 229 | + print("z", node) #, node.candidates_quantization_cfg |
| 230 | + for ii in range(len(node.candidates_quantization_cfg)): |
| 231 | + print('z4', ii, node.candidates_quantization_cfg[ii].weights_quantization_cfg.attributes_config_mapping) |
| 232 | + #print('z4 0', ii, type(node.candidates_quantization_cfg[ii].weights_quantization_cfg.attributes_config_mapping)) |
| 233 | + for vkey in node.candidates_quantization_cfg[ii].weights_quantization_cfg.attributes_config_mapping: |
| 234 | + #print('z5', vkey, node.candidates_quantization_cfg[ii].weights_quantization_cfg.attributes_config_mapping[vkey]) |
| 235 | + cfg = node.candidates_quantization_cfg[ii].weights_quantization_cfg.attributes_config_mapping[vkey] |
| 236 | + print('z5 cfg.weights_n_bits', cfg.weights_n_bits) |
| 237 | + |
| 238 | + """ |
| 239 | + for val2 in node.get_node_weights_attributes(): |
| 240 | + print("z2", val2, type(val2)) |
| 241 | + a = node.weights[val2] |
| 242 | + print('a', a) |
| 243 | + """ |
| 244 | + |
| 245 | + #assert graph == updated_graph |
| 246 | + |
| 247 | + |
| 248 | + print('# test_manual_weights_bitwidth_selection end.') |
| 249 | + |
| 250 | + pass |
| 251 | + |
0 commit comments