|
| 1 | +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +import abc |
| 16 | +import copy |
| 17 | +from typing import Callable, Generator |
| 18 | + |
| 19 | +from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo |
| 20 | +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation |
| 21 | +from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \ |
| 22 | + VirtualSplitActivationNode, VirtualSplitWeightsNode |
| 23 | +from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \ |
| 24 | + RUTarget, ResourceUtilization |
| 25 | +from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \ |
| 26 | + ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode |
| 27 | +from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute |
| 28 | +from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner |
| 29 | +from model_compression_toolkit.target_platform_capabilities import QuantizationMethod, AttributeQuantizationConfig, \ |
| 30 | + OpQuantizationConfig, QuantizationConfigOptions, Signedness, OperatorsSet, OperatorSetNames, \ |
| 31 | + TargetPlatformCapabilities |
| 32 | +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR |
| 33 | + |
| 34 | + |
def build_tpc():
    """ Construct a minimal TPC containing linear and binary ops for the RU integration tests.

    Linear ops get a configurable (activation + weights) mixed-precision options set. The chosen
    bit-widths make default / linear / binary activations, and const / linear weights, all
    distinguishable from one another, so tests can tell which config was applied where.

    Returns:
        Tuple of (tpc, linear weights min nbits, linear activation min nbits,
        default weights nbits, default activation nbits, binary output activation nbits).
    """
    base_attr_cfg = AttributeQuantizationConfig(weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
                                                weights_n_bits=8,
                                                weights_per_channel_threshold=True,
                                                enable_weights_quantization=True)
    const_w_nbit = 16  # default weights bit-width (applies to constants via default_weight_attr_config)
    base_a_nbit = 8    # default activation bit-width
    base_op_cfg = OpQuantizationConfig(
        default_weight_attr_config=base_attr_cfg.clone_and_edit(weights_n_bits=const_w_nbit),
        attr_weights_configs_mapping={},
        activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
        activation_n_bits=base_a_nbit,
        supported_input_activation_n_bits=[16, 8, 4, 2],
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=32,
        signedness=Signedness.AUTO)

    # Base config for kernel ops: quantized kernel, non-quantized bias.
    kernel_op_cfg = base_op_cfg.clone_and_edit(
        attr_weights_configs_mapping={KERNEL_ATTR: base_attr_cfg, BIAS_ATTR: AttributeQuantizationConfig()}
    )

    min_w_nbit = 2
    min_a_nbit = 4
    # Candidate (weights x activation) bit-width product; weights vary in the outer position,
    # matching the order an equivalent nested for-loop would produce.
    mp_configs = [
        kernel_op_cfg.clone_and_edit(
            attr_weights_configs_mapping={KERNEL_ATTR: base_attr_cfg.clone_and_edit(weights_n_bits=w_nbit),
                                          BIAS_ATTR: AttributeQuantizationConfig()},
            activation_n_bits=a_nbit
        )
        for w_nbit in (min_w_nbit, min_w_nbit * 2, min_w_nbit * 4)
        for a_nbit in (min_a_nbit, min_a_nbit * 2)
    ]
    mp_cfg_options = QuantizationConfigOptions(quantization_configurations=mp_configs,
                                               base_config=kernel_op_cfg)

    linear_opset_names = (OperatorSetNames.CONV,
                          OperatorSetNames.CONV_TRANSPOSE,
                          OperatorSetNames.DEPTHWISE_CONV,
                          OperatorSetNames.FULLY_CONNECTED)
    linear_ops = [OperatorsSet(name=opset_name, qc_options=mp_cfg_options) for opset_name in linear_opset_names]

    default_cfg = QuantizationConfigOptions(quantization_configurations=[base_op_cfg])

    # Binary (add/sub) ops use a dedicated output activation bit-width.
    binary_out_a_bit = 16
    binary_cfg = QuantizationConfigOptions(
        quantization_configurations=[base_op_cfg.clone_and_edit(activation_n_bits=binary_out_a_bit)]
    )
    binary_ops = [OperatorsSet(name=opset_name, qc_options=binary_cfg)
                  for opset_name in (OperatorSetNames.ADD, OperatorSetNames.SUB)]

    tpc = TargetPlatformCapabilities(default_qco=default_cfg,
                                     tpc_platform_type='test',
                                     operator_set=linear_ops + binary_ops,
                                     fusing_patterns=None)

    # Sanity: the distinguishing bit-widths must actually be distinguishable.
    assert min_w_nbit != const_w_nbit
    assert len({min_a_nbit, base_a_nbit, binary_out_a_bit}) == 3
    return tpc, min_w_nbit, min_a_nbit, const_w_nbit, base_a_nbit, binary_out_a_bit
| 97 | + |
| 98 | + |
class BaseRUIntegrationTester(abc.ABC):
    """ Test resource utilization calculator on a real framework model with graph preparation.

    Concrete subclasses provide the framework-specific pieces (model builders, the data generator
    and the fw_* / attach_to_fw_func class attributes); the test methods here are
    framework-agnostic. They compare detailed resource-utilization (RU) results between an
    original prepared graph and its "virtual" counterpart produced by the weights/activation
    coupling substitutions.
    """
    # Framework-specific hooks; must be supplied by concrete subclasses.
    fw_info: FrameworkInfo
    fw_impl: FrameworkImplementation
    attach_to_fw_func: Callable  # attaches the framework-agnostic TPC to the concrete framework

    # Input shape in (batch, height, width, channels) layout, shared by all test models.
    bhwc_input_shape = (1, 18, 18, 3)

    @abc.abstractmethod
    def _build_sequential_model(self):
        r""" build framework model for test_orig_vs_virtual_sequential_graph:
            conv2d(k=5, filters=8) -> add const(14, 14, 8) -> dwconv(k=3, dm=2) ->
            -> conv_transpose(k=5, filters=12) -> flatten -> fc(10)
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _build_mult_output_activation_model(self):
        r""" build framework model for test_mult_output_activation:
            x - conv2d(k=3, filters=15, groups=3) \ subtract -> flatten -> fc(10)
              \ dwconv2d(k=3, dm=5)               /
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _data_gen(self) -> Generator:
        """ Build framework datagen with 'bhwc_input_shape' """
        raise NotImplementedError()

    def test_orig_vs_virtual_graph_ru(self):
        """ Test detailed ru computation on original and the corresponding virtual graphs. """
        # Model is sequential so that activation cuts are uniquely defined.
        model = self._build_sequential_model()
        # Test the original graph. Linear collapsing is disabled so each linear op remains a
        # separate node with its own quantization candidates.
        graph, nbits = self._prepare_graph(model, disable_linear_collapse=True)
        linear_w_min_nbit, linear_a_min_nbit, default_w_nbits, default_a_nbit, binary_out_a_bit = nbits

        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        # QMinBit: evaluate RU with every configurable quantizer at its minimal bit-width.
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                      BitwidthMode.QMinBit,
                                                                      return_detailed=True)

        # Expected per-cut activation memory in bytes: sum over the tensors alive in the cut of
        # (tensor volume * activation nbits) / 8, following the architecture described in
        # _build_sequential_model. NOTE(review): the 14x14 / 12x12 / 16x16 spatial sizes are taken
        # as given by the model builders — presumably valid-padding conv arithmetic on the
        # 18x18 input; confirm against the concrete framework model.
        exp_cuts_ru = [18 * 18 * 3 * default_a_nbit / 8,
                       (18 * 18 * 3 * default_a_nbit + 14 * 14 * 8 * linear_a_min_nbit) / 8,
                       (14 * 14 * 8 * linear_a_min_nbit + 14 * 14 * 8 * binary_out_a_bit) / 8,
                       (14 * 14 * 8 * binary_out_a_bit + 12 * 12 * 16 * linear_a_min_nbit) / 8,
                       (12 * 12 * 16 * linear_a_min_nbit + 16 * 16 * 12 * linear_a_min_nbit) / 8,
                       (16 * 16 * 12 * linear_a_min_nbit + 16 * 16 * 12 * default_a_nbit) / 8,
                       (16 * 16 * 12 * default_a_nbit + 10 * linear_a_min_nbit) / 8,
                       10 * linear_a_min_nbit / 8]
        assert self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru)

        # Expected per-node weights memory in bytes: kernel volume * weights nbits / 8.
        # The add-const weight is not a kernel, so it uses the default weights nbits.
        exp_w_ru = [5 * 5 * 3 * 8 * linear_w_min_nbit / 8,
                    14 * 14 * 8 * default_w_nbits / 8,  # const
                    3 * 3 * 8 * 2 * linear_w_min_nbit / 8,
                    5 * 5 * 16 * 12 * linear_w_min_nbit / 8,
                    (16*16*12) * 10 * linear_w_min_nbit / 8]
        assert self._extract_values(detailed_orig[RUTarget.WEIGHTS]) == exp_w_ru

        # Expected BOPS per linear op: MACs (kernel volume * output spatial size) *
        # input activation nbits * kernel weights nbits.
        exp_bops = [(5 * 5 * 3 * 8) * (14 * 14) * default_a_nbit * linear_w_min_nbit,
                    (3 * 3 * 8 * 2) * (12 * 12) * binary_out_a_bit * linear_w_min_nbit,
                    (5 * 5 * 16 * 12) * (16 * 16) * linear_a_min_nbit * linear_w_min_nbit,
                    (16 * 16 * 12) * 10 * default_a_nbit * linear_w_min_nbit]
        assert self._extract_values(detailed_orig[RUTarget.BOPS]) == exp_bops

        # Aggregate RU: activation memory is the max cut; weights/bops are totals.
        assert ru_orig == ResourceUtilization(activation_memory=max(exp_cuts_ru),
                                              weights_memory=sum(exp_w_ru),
                                              total_memory=max(exp_cuts_ru) + sum(exp_w_ru),
                                              bops=sum(exp_bops))

        # generate virtual graph and make sure resource utilization results are identical
        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        assert len(virtual_graph.nodes) == 7
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 4
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitActivationNode)]) == 3

        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                            BitwidthMode.QMinBit,
                                                                            return_detailed=True)
        assert ru_virtual == ru_orig

        assert (self._extract_values(detailed_virtual[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru))
        # virtual composed node contains both activation's const and weights' kernel
        assert (self._extract_values(detailed_virtual[RUTarget.WEIGHTS]) ==
                [exp_w_ru[0], sum(exp_w_ru[1:3]), *exp_w_ru[3:]])
        assert self._extract_values(detailed_virtual[RUTarget.BOPS]) == exp_bops

    def test_sequential_graph_with_default_quant_cfg(self):
        """ Using the default quant config, make sure the original and the virtual graphs yield the same results. """
        model = self._build_sequential_model()
        graph, *_ = self._prepare_graph(model)

        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        # QMaxBit: evaluate RU with every quantizer at its maximal candidate bit-width.
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.Any,
                                                                      BitwidthMode.QMaxBit,
                                                                      return_detailed=True)

        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.Any,
                                                                            BitwidthMode.QMaxBit,
                                                                            return_detailed=True)
        # Original and virtual graphs must agree on the aggregate and on every detailed target.
        assert ru_orig == ru_virtual
        assert (self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) ==
                self._extract_values(detailed_virtual[RUTarget.ACTIVATION], sort=True))
        assert (self._extract_values(detailed_orig[RUTarget.WEIGHTS]) ==
                self._extract_values(detailed_virtual[RUTarget.WEIGHTS]))
        assert (self._extract_values(detailed_orig[RUTarget.TOTAL], sort=True) ==
                self._extract_values(detailed_virtual[RUTarget.TOTAL], sort=True))
        assert (self._extract_values(detailed_orig[RUTarget.BOPS]) ==
                self._extract_values(detailed_virtual[RUTarget.BOPS]))

    def test_mult_output_activation(self):
        """ Tests the case when input activation has multiple outputs -> virtual weights nodes are not merged
            into VirtualActivationWeightsNode. """
        model = self._build_mult_output_activation_model()

        graph, nbits = self._prepare_graph(model)
        linear_w_min_nbit, linear_a_min_nbit, default_w_nbits, default_a_nbit, binary_out_a_bit = nbits

        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                      BitwidthMode.QMinBit,
                                                                      return_detailed=True)

        # Expected per-cut activation memory (bytes) for the two-branch model from
        # _build_mult_output_activation_model; the 16x16x15 branch outputs appear in pairs
        # while both branches are alive.
        exp_cuts_ru = [18*18*3 * default_a_nbit/8,
                       (18*18*3 * default_a_nbit + 16*16*15 * linear_a_min_nbit) / 8,
                       (18*18*3 * default_a_nbit + 2 * (16*16*15 * linear_a_min_nbit)) / 8,
                       16*16*15 * (2*linear_a_min_nbit + binary_out_a_bit) / 8,
                       (16*16*15 * (binary_out_a_bit + default_a_nbit)) / 8,
                       (16*16*15 * default_a_nbit + 10 * linear_a_min_nbit) / 8,
                       10 * linear_a_min_nbit / 8]

        # the order of conv and dwconv is not guaranteed, but they have same values
        exp_w_ru = [3*3*1*15 * linear_w_min_nbit/8,
                    3*3*3*5 * linear_w_min_nbit/8,
                    16*16*15*10 * linear_w_min_nbit/8]
        # bops are not computed for virtual weights nodes
        exp_bops = [(16*16*15*10)*default_a_nbit*linear_w_min_nbit]

        assert self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru)
        assert self._extract_values(detailed_orig[RUTarget.WEIGHTS]) == exp_w_ru
        assert self._extract_values(detailed_orig[RUTarget.BOPS]) == exp_bops

        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        # Since the shared input activation feeds two ops, only one composed A+W node is formed;
        # the conv/dwconv weights remain as stand-alone VirtualSplitWeightsNode nodes.
        assert len(virtual_graph.nodes) == 8
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 1
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitActivationNode)]) == 3
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitWeightsNode)]) == 2

        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                            BitwidthMode.QMinBit,
                                                                            return_detailed=True)
        assert ru_virtual == ru_orig
        # conv and dwconv each remain as a pair of virtual W and virtual A nodes. Remaining virtual W nodes mess up the
        # cuts - but this should only add virtualW-virtualA cuts, all cuts from the original graph should stay identical
        assert not set(exp_cuts_ru) - set(detailed_virtual[RUTarget.ACTIVATION].values())
        assert self._extract_values(detailed_virtual[RUTarget.WEIGHTS]) == exp_w_ru
        assert self._extract_values(detailed_virtual[RUTarget.BOPS]) == exp_bops

    def _prepare_graph(self, model, disable_linear_collapse: bool=False):
        """ Run graph preparation on 'model' with the test TPC, in mixed-precision mode.

        Args:
            model: framework model to prepare.
            disable_linear_collapse: if True, disable the linear-collapsing substitution
              (needed when a test counts linear nodes individually).

        Returns:
            Tuple of (prepared graph, list of the TPC's distinguishing nbits as returned
            by build_tpc).
        """
        # If disable_linear_collapse is False we use the default quantization config
        tpc, *nbits = build_tpc()
        qcfg = QuantizationConfig(linear_collapsing=False) if disable_linear_collapse else QuantizationConfig()
        graph = graph_preparation_runner(model,
                                         self._data_gen,
                                         quantization_config=qcfg,
                                         fw_info=self.fw_info,
                                         fw_impl=self.fw_impl,
                                         fqc=self.attach_to_fw_func(tpc),
                                         mixed_precision_enable=True,
                                         running_gptq=False)
        return graph, nbits

    @staticmethod
    def _extract_values(res: dict, sort=False):
        """ Extract values for target detailed resource utilization result.

        Args:
            res: per-element detailed utilization mapping for a single RU target.
            sort: if True, return the values sorted (for order-insensitive comparison).

        Returns:
            List of the mapping's values, optionally sorted.
        """
        values = list(res.values())
        return sorted(values) if sort else values
0 commit comments