Skip to content

Commit 0cd32e7

Browse files
irenabirenab
authored and committed
add ru calculator integration test for torch
1 parent d10483b commit 0cd32e7

File tree

6 files changed

+462
-215
lines changed

6 files changed

+462
-215
lines changed
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import abc
16+
17+
from typing import Callable, Generator
18+
19+
import copy
20+
21+
import numpy as np
22+
23+
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo
24+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
25+
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
26+
VirtualSplitActivationNode, VirtualSplitWeightsNode
27+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
28+
RUTarget, ResourceUtilization
29+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
30+
ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
31+
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
32+
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
33+
from model_compression_toolkit.target_platform_capabilities import QuantizationMethod, AttributeQuantizationConfig, \
34+
OpQuantizationConfig, QuantizationConfigOptions, Signedness, OperatorsSet, OperatorSetGroup, OperatorSetNames, \
35+
Fusing, TargetPlatformCapabilities
36+
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR
37+
38+
39+
def build_tpc():
    """Build a synthetic target platform capabilities (TPC) object for the RU-calculator tests.

    Linear ops (conv / conv-transpose / depthwise-conv / fully-connected) get a
    mixed-precision space of 3 kernel bit-widths x 2 activation bit-widths.
    Binary ops (add / sub) get a fixed 16-bit output activation, relu gets the
    default config, and linear->relu fusing is declared.

    Returns:
        Tuple (tpc, linear_w_min_nbit, linear_a_min_nbit, default_w_nbit,
        default_a_nbit, binary_out_a_bit) — the bit-widths are returned so the
        tests can compute expected utilization values from them.
    """
    base_w_attr_cfg = AttributeQuantizationConfig(
        weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
        weights_n_bits=8,
        weights_per_channel_threshold=True,
        enable_weights_quantization=True)

    default_w_nbit = 16
    default_a_nbit = 8

    base_op_cfg = OpQuantizationConfig(
        default_weight_attr_config=base_w_attr_cfg.clone_and_edit(weights_n_bits=default_w_nbit),
        attr_weights_configs_mapping={},
        activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
        activation_n_bits=default_a_nbit,
        supported_input_activation_n_bits=[16, 8, 4, 2],
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=32,
        signedness=Signedness.AUTO)

    # Base config for ops that carry a kernel and a bias attribute.
    base_w_op_cfg = base_op_cfg.clone_and_edit(
        attr_weights_configs_mapping={KERNEL_ATTR: base_w_attr_cfg,
                                      BIAS_ATTR: AttributeQuantizationConfig()})

    # Mixed-precision candidates: kernel in {2, 4, 8} bits, activation in {4, 8} bits.
    linear_w_min_nbit = 2
    linear_a_min_nbit = 4
    mp_configs = [
        base_w_op_cfg.clone_and_edit(
            attr_weights_configs_mapping={
                KERNEL_ATTR: base_w_attr_cfg.clone_and_edit(weights_n_bits=w_nbit),
                BIAS_ATTR: AttributeQuantizationConfig()},
            activation_n_bits=a_nbit)
        for w_nbit in (linear_w_min_nbit, linear_w_min_nbit * 2, linear_w_min_nbit * 4)
        for a_nbit in (linear_a_min_nbit, linear_a_min_nbit * 2)
    ]
    mp_cfg_options = QuantizationConfigOptions(quantization_configurations=mp_configs,
                                               base_config=base_w_op_cfg)

    linear_opset_names = (OperatorSetNames.CONV,
                          OperatorSetNames.CONV_TRANSPOSE,
                          OperatorSetNames.DEPTHWISE_CONV,
                          OperatorSetNames.FULLY_CONNECTED)
    linear_ops = [OperatorsSet(name=opset, qc_options=mp_cfg_options) for opset in linear_opset_names]

    default_cfg = QuantizationConfigOptions(quantization_configurations=[base_op_cfg])
    relu = OperatorsSet(name=OperatorSetNames.RELU, qc_options=default_cfg)

    # Binary ops emit a distinct (16-bit) output activation so their cuts are distinguishable.
    binary_out_a_bit = 16
    binary_cfg = QuantizationConfigOptions(
        quantization_configurations=[base_op_cfg.clone_and_edit(activation_n_bits=binary_out_a_bit)])
    binary_ops = [OperatorsSet(name=opset, qc_options=binary_cfg)
                  for opset in (OperatorSetNames.ADD, OperatorSetNames.SUB)]

    fusing_patterns = [Fusing(operator_groups=(OperatorSetGroup(operators_set=linear_ops), relu))]

    tpc = TargetPlatformCapabilities(default_qco=default_cfg,
                                     tpc_platform_type='test',
                                     operator_set=linear_ops + binary_ops + [relu],
                                     fusing_patterns=fusing_patterns)

    # Sanity: the tests rely on these bit-widths being mutually distinguishable.
    assert linear_w_min_nbit != default_w_nbit
    assert len({linear_a_min_nbit, default_a_nbit, binary_out_a_bit}) == 3
    return tpc, linear_w_min_nbit, linear_a_min_nbit, default_w_nbit, default_a_nbit, binary_out_a_bit
98+
99+
100+
def data_gen():
    """Yield a single random input sample of shape (16, 16, 3), wrapped in a list."""
    sample = np.random.randn(16, 16, 3)
    yield [sample]
102+
103+
104+
class BaseRUIntegrationTester(abc.ABC):
    """ Test resource utilization calculator on a real framework model with graph preparation.

    Framework-specific subclasses supply the model builders, the data generator,
    and the framework hooks (fw_info / fw_impl / attach_to_fw_func) below.
    """
    # To be set by framework-specific subclasses (e.g. the torch tester).
    fw_info: FrameworkInfo
    fw_impl: FrameworkImplementation
    attach_to_fw_func: Callable  # attaches the TPC to the framework quantization capabilities

    # Input shape (batch, height, width, channels) for the model/data generators.
    bhwc_input_shape = (1, 18, 18, 3)

    @abc.abstractmethod
    def _build_sequential_model(self):
        r""" build framework model for test_orig_vs_virtual_sequential_graph:
        conv2d(k=5, filters=8) -> add const(14, 14, 8) -> dwconv(k=3, dm=2) ->
        -> conv_transpose(k=5, filters=12) -> flatten -> fc(10)
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _build_mult_output_activation_model(self):
        r""" build framework model for test_mult_output_activation:
        x - conv2d(k=3, filters=15, groups=3) \ subtract -> flatten -> fc(10)
          \ dwconv2d(k=3, dm=5)               /
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _data_gen(self) -> Generator:
        """ Build framework datagen with 'bhwc_input_shape' """
        raise NotImplementedError()

    def test_orig_vs_virtual_graph_ru(self):
        """ Test detailed ru computation on original and the corresponding virtual graphs. """
        # model is sequential so that activation cuts are well uniquely defined.
        model = self._build_sequential_model()
        # test the original graph (linear collapsing disabled so each op stays a separate node)
        graph, nbits = self._prepare_graph(model, disable_linear_collapse=True)
        linear_w_min_nbit, linear_a_min_nbit, default_w_nbits, default_a_nbit, binary_out_a_bit = nbits

        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                      BitwidthMode.QMinBit,
                                                                      return_detailed=True)

        # Expected per-cut activation memory in bytes: for each cut, sum over the live
        # tensors of (tensor size * activation bit-width) / 8. Tensor sizes follow the
        # sequential model documented in _build_sequential_model.
        exp_cuts_ru = [18 * 18 * 3 * default_a_nbit / 8,
                       (18 * 18 * 3 * default_a_nbit + 14 * 14 * 8 * linear_a_min_nbit) / 8,
                       (14 * 14 * 8 * linear_a_min_nbit + 14 * 14 * 8 * binary_out_a_bit) / 8,
                       (14 * 14 * 8 * binary_out_a_bit + 12 * 12 * 16 * linear_a_min_nbit) / 8,
                       (12 * 12 * 16 * linear_a_min_nbit + 16 * 16 * 12 * linear_a_min_nbit) / 8,
                       (16 * 16 * 12 * linear_a_min_nbit + 16 * 16 * 12 * default_a_nbit) / 8,
                       (16 * 16 * 12 * default_a_nbit + 10 * linear_a_min_nbit) / 8,
                       10 * linear_a_min_nbit / 8]
        assert self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru)

        # Expected per-node weights memory in bytes: kernel size * min weight bit-width / 8.
        exp_w_ru = [5 * 5 * 3 * 8 * linear_w_min_nbit / 8,
                    14 * 14 * 8 * default_w_nbits / 8,  # const
                    3 * 3 * 8 * 2 * linear_w_min_nbit / 8,
                    5 * 5 * 16 * 12 * linear_w_min_nbit / 8,
                    (16*16*12) * 10 * linear_w_min_nbit / 8]
        assert self._extract_values(detailed_orig[RUTarget.WEIGHTS]) == exp_w_ru

        # Expected BOPS per linear op: MACs * input-activation bits * kernel bits.
        exp_bops = [(5 * 5 * 3 * 8) * (14 * 14) * default_a_nbit * linear_w_min_nbit,
                    (3 * 3 * 8 * 2) * (12 * 12) * binary_out_a_bit * linear_w_min_nbit,
                    (5 * 5 * 16 * 12) * (16 * 16) * linear_a_min_nbit * linear_w_min_nbit,
                    (16 * 16 * 12) * 10 * default_a_nbit * linear_w_min_nbit]
        assert self._extract_values(detailed_orig[RUTarget.BOPS]) == exp_bops

        # Aggregate RU should match the detailed results (max cut, total weights, total bops).
        assert ru_orig == ResourceUtilization(activation_memory=max(exp_cuts_ru),
                                              weights_memory=sum(exp_w_ru),
                                              total_memory=max(exp_cuts_ru) + sum(exp_w_ru),
                                              bops=sum(exp_bops))

        # generate virtual graph and make sure resource utilization results are identical
        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        assert len(virtual_graph.nodes) == 7
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 4
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitActivationNode)]) == 3

        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                           BitwidthMode.QMinBit,
                                                                           return_detailed=True)
        assert ru_virtual == ru_orig

        assert (self._extract_values(detailed_virtual[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru))
        # virtual composed node contains both activation's const and weights' kernel,
        # so entries 1 and 2 of the original per-node weights list are merged into one.
        assert (self._extract_values(detailed_virtual[RUTarget.WEIGHTS]) ==
                [exp_w_ru[0], sum(exp_w_ru[1:3]), *exp_w_ru[3:]])
        assert self._extract_values(detailed_virtual[RUTarget.BOPS]) == exp_bops

    def test_sequential_graph_with_default_quant_cfg(self):
        """ Using the default quant config, make sure the original and the virtual graphs yield the same results. """
        model = self._build_sequential_model()
        graph, *_ = self._prepare_graph(model)

        # RU on the original graph, at maximal bit-width over all nodes.
        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.Any,
                                                                      BitwidthMode.QMaxBit,
                                                                      return_detailed=True)

        # RU on the virtual (activation-weights coupled) graph with the same settings.
        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.Any,
                                                                            BitwidthMode.QMaxBit,
                                                                            return_detailed=True)
        # Both the aggregate and every detailed target must agree between the two graphs.
        assert ru_orig == ru_virtual
        assert (self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) ==
                self._extract_values(detailed_virtual[RUTarget.ACTIVATION], sort=True))
        assert (self._extract_values(detailed_orig[RUTarget.WEIGHTS]) ==
                self._extract_values(detailed_virtual[RUTarget.WEIGHTS]))
        assert (self._extract_values(detailed_orig[RUTarget.TOTAL], sort=True) ==
                self._extract_values(detailed_virtual[RUTarget.TOTAL], sort=True))
        assert (self._extract_values(detailed_orig[RUTarget.BOPS]) ==
                self._extract_values(detailed_virtual[RUTarget.BOPS]))

    def test_mult_output_activation(self):
        """ Tests the case when input activation has multiple outputs -> virtual weights nodes are not merged
        into VirtualActivationWeightsNode. """
        model = self._build_mult_output_activation_model()

        graph, nbits = self._prepare_graph(model)
        linear_w_min_nbit, linear_a_min_nbit, default_w_nbits, default_a_nbit, binary_out_a_bit = nbits

        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                      BitwidthMode.QMinBit,
                                                                      return_detailed=True)

        # Expected per-cut activation memory (bytes) for the two-branch model from
        # _build_mult_output_activation_model (both branches output 16x16x15).
        exp_cuts_ru = [18*18*3 * default_a_nbit/8,
                       (18*18*3 * default_a_nbit + 16*16*15 * linear_a_min_nbit) / 8,
                       (18*18*3 * default_a_nbit + 2 * (16*16*15 * linear_a_min_nbit)) / 8,
                       16*16*15 * (2*linear_a_min_nbit + binary_out_a_bit) / 8,
                       (16*16*15 * (binary_out_a_bit + default_a_nbit)) / 8,
                       (16*16*15 * default_a_nbit + 10 * linear_a_min_nbit) / 8,
                       10 * linear_a_min_nbit / 8]

        # the order of conv and dwconv is not guaranteed, but they have same values
        exp_w_ru = [3*3*1*15 * linear_w_min_nbit/8,
                    3*3*3*5 * linear_w_min_nbit/8,
                    16*16*15*10 * linear_w_min_nbit/8]
        # bops are not computed for virtual weights nodes
        exp_bops = [(16*16*15*10)*default_a_nbit*linear_w_min_nbit]

        assert self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru)
        assert self._extract_values(detailed_orig[RUTarget.WEIGHTS]) == exp_w_ru
        assert self._extract_values(detailed_orig[RUTarget.BOPS]) == exp_bops

        # Build the virtual graph; the shared input activation prevents merging the
        # two branch convs into VirtualActivationWeightsNodes.
        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        assert len(virtual_graph.nodes) == 8
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 1
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitActivationNode)]) == 3
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitWeightsNode)]) == 2

        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                           BitwidthMode.QMinBit,
                                                                           return_detailed=True)
        assert ru_virtual == ru_orig
        # conv and dwconv each remain as a pair of virtual W and virtual A nodes. Remaining virtual W nodes mess up the
        # cuts - but this should only add virtualW-virtualA cuts, all cuts from the original graph should stay identical
        assert not set(exp_cuts_ru) - set(detailed_virtual[RUTarget.ACTIVATION].values())
        assert self._extract_values(detailed_virtual[RUTarget.WEIGHTS]) == exp_w_ru
        assert self._extract_values(detailed_virtual[RUTarget.BOPS]) == exp_bops

    def _prepare_graph(self, model, disable_linear_collapse: bool = False):
        """ Run graph preparation on `model` with the test TPC.

        Returns a tuple (graph, nbits) where nbits is the list of bit-widths
        produced by build_tpc (all values after the tpc itself).
        """
        tpc, *nbits = build_tpc()
        # disable_linear_collapse=True turns off linear collapsing so each linear op
        # stays a separate node; otherwise the default quantization config is used.
        qcfg = QuantizationConfig(linear_collapsing=False) if disable_linear_collapse else QuantizationConfig()
        graph = graph_preparation_runner(model,
                                         self._data_gen,
                                         quantization_config=qcfg,
                                         fw_info=self.fw_info,
                                         fw_impl=self.fw_impl,
                                         fqc=self.attach_to_fw_func(tpc),
                                         mixed_precision_enable=True,
                                         running_gptq=False)
        return graph, nbits

    @staticmethod
    def _extract_values(res: dict, sort=False) -> list:
        """ Extract values for target detailed resource utilization result. """
        values = list(res.values())
        return sorted(values) if sort else values

0 commit comments

Comments
 (0)