Skip to content

Commit a7e0f45

Browse files
irenab
authored and committed
add ru calculator integration test for torch
1 parent d10483b commit a7e0f45

File tree

6 files changed

+462
-214
lines changed

6 files changed

+462
-214
lines changed
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import abc
16+
import copy
17+
from typing import Callable, Generator
18+
19+
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo
20+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
21+
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
22+
VirtualSplitActivationNode, VirtualSplitWeightsNode
23+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
24+
RUTarget, ResourceUtilization
25+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
26+
ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
27+
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
28+
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
29+
from model_compression_toolkit.target_platform_capabilities import QuantizationMethod, AttributeQuantizationConfig, \
30+
OpQuantizationConfig, QuantizationConfigOptions, Signedness, OperatorsSet, OperatorSetNames, \
31+
TargetPlatformCapabilities
32+
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR
33+
34+
35+
def build_tpc():
    """ Build minimal tpc containing linear and binary ops, configurable a+w for linear ops,
    distinguishable nbits for default / linear / binary activation and for const / linear weights.

    Returns a tuple of (tpc, linear_w_min_nbit, linear_a_min_nbit, default_w_nbit,
    default_a_nbit, binary_out_a_bit) so tests can reference the distinguishing bit-widths.
    """
    # Base attribute config for kernel-like weights (8 bits, per-channel, quantized).
    kernel_attr_cfg = AttributeQuantizationConfig(weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
                                                  weights_n_bits=8,
                                                  weights_per_channel_threshold=True,
                                                  enable_weights_quantization=True)

    # Bit-widths chosen to be mutually distinguishable in test expectations.
    default_w_nbit, default_a_nbit = 16, 8

    # Default op config: no kernel attr mapping, so only the default weight attr (const) applies.
    default_op_cfg = OpQuantizationConfig(
        default_weight_attr_config=kernel_attr_cfg.clone_and_edit(weights_n_bits=default_w_nbit),
        attr_weights_configs_mapping={},
        activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
        activation_n_bits=default_a_nbit,
        supported_input_activation_n_bits=[16, 8, 4, 2],
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=32,
        signedness=Signedness.AUTO)

    # Variant of the default op config that carries a kernel/bias attribute mapping (for linear ops).
    linear_base_op_cfg = default_op_cfg.clone_and_edit(
        attr_weights_configs_mapping={KERNEL_ATTR: kernel_attr_cfg, BIAS_ATTR: AttributeQuantizationConfig()}
    )

    # Mixed-precision candidates: weights in {2, 4, 8} bits crossed with activations in {4, 8} bits.
    linear_w_min_nbit, linear_a_min_nbit = 2, 4
    mp_configs = [
        linear_base_op_cfg.clone_and_edit(
            attr_weights_configs_mapping={KERNEL_ATTR: kernel_attr_cfg.clone_and_edit(weights_n_bits=w_bits),
                                          BIAS_ATTR: AttributeQuantizationConfig()},
            activation_n_bits=a_bits
        )
        for w_bits in (linear_w_min_nbit, linear_w_min_nbit * 2, linear_w_min_nbit * 4)
        for a_bits in (linear_a_min_nbit, linear_a_min_nbit * 2)
    ]
    mp_cfg_options = QuantizationConfigOptions(quantization_configurations=mp_configs,
                                               base_config=linear_base_op_cfg)

    # All linear op sets share the same mixed-precision options.
    linear_opset_names = (OperatorSetNames.CONV,
                          OperatorSetNames.CONV_TRANSPOSE,
                          OperatorSetNames.DEPTHWISE_CONV,
                          OperatorSetNames.FULLY_CONNECTED)
    linear_ops = [OperatorsSet(name=name, qc_options=mp_cfg_options) for name in linear_opset_names]

    default_cfg = QuantizationConfigOptions(quantization_configurations=[default_op_cfg])

    # Binary (add/sub) ops get a single fixed config with a distinct 16-bit output activation.
    binary_out_a_bit = 16
    binary_cfg = QuantizationConfigOptions(
        quantization_configurations=[default_op_cfg.clone_and_edit(activation_n_bits=binary_out_a_bit)]
    )
    binary_ops = [OperatorsSet(name=name, qc_options=binary_cfg)
                  for name in (OperatorSetNames.ADD, OperatorSetNames.SUB)]

    tpc = TargetPlatformCapabilities(default_qco=default_cfg,
                                     tpc_platform_type='test',
                                     operator_set=linear_ops + binary_ops,
                                     fusing_patterns=None)

    # Sanity: the bit-widths the tests rely on must indeed be distinguishable.
    assert linear_w_min_nbit != default_w_nbit
    assert len({linear_a_min_nbit, default_a_nbit, binary_out_a_bit}) == 3
    return tpc, linear_w_min_nbit, linear_a_min_nbit, default_w_nbit, default_a_nbit, binary_out_a_bit
98+
99+
class BaseRUIntegrationTester(abc.ABC):
    """ Test resource utilization calculator on a real framework model with graph preparation.

    Framework-specific subclasses supply the model builders, data generator and the
    fw_info / fw_impl / attach_to_fw_func class attributes.
    """
    # Framework-specific objects set by the concrete subclass (e.g. torch / keras).
    fw_info: FrameworkInfo
    fw_impl: FrameworkImplementation
    attach_to_fw_func: Callable

    # Input shape in batch-height-width-channels layout, used by _data_gen implementations.
    bhwc_input_shape = (1, 18, 18, 3)

    @abc.abstractmethod
    def _build_sequential_model(self):
        r""" build framework model for test_orig_vs_virtual_sequential_graph:
        conv2d(k=5, filters=8) -> add const(14, 14, 8) -> dwconv(k=3, dm=2) ->
        -> conv_transpose(k=5, filters=12) -> flatten -> fc(10)
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _build_mult_output_activation_model(self):
        r""" build framework model for test_mult_output_activation:
        x - conv2d(k=3, filters=15, groups=3) \ subtract -> flatten -> fc(10)
          \ dwconv2d(k=3, dm=5)               /
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _data_gen(self) -> Generator:
        """ Build framework datagen with 'bhwc_input_shape' """
        raise NotImplementedError()

    def test_orig_vs_virtual_graph_ru(self):
        """ Test detailed ru computation on original and the corresponding virtual graphs.

        Expected values are hand-computed from the sequential model layout declared in
        _build_sequential_model and the bit-widths returned by build_tpc.
        """
        # model is sequential so that activation cuts are well uniquely defined.
        model = self._build_sequential_model()
        # test the original graph; linear collapsing is disabled so each linear op stays a node.
        graph, nbits = self._prepare_graph(model, disable_linear_collapse=True)
        linear_w_min_nbit, linear_a_min_nbit, default_w_nbits, default_a_nbit, binary_out_a_bit = nbits

        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                     BitwidthMode.QMinBit,
                                                                     return_detailed=True)

        # Expected per-cut activation memory in bytes (hence the /8): each cut holds the
        # activations alive across an edge of the sequential graph.
        exp_cuts_ru = [18 * 18 * 3 * default_a_nbit / 8,
                       (18 * 18 * 3 * default_a_nbit + 14 * 14 * 8 * linear_a_min_nbit) / 8,
                       (14 * 14 * 8 * linear_a_min_nbit + 14 * 14 * 8 * binary_out_a_bit) / 8,
                       (14 * 14 * 8 * binary_out_a_bit + 12 * 12 * 16 * linear_a_min_nbit) / 8,
                       (12 * 12 * 16 * linear_a_min_nbit + 16 * 16 * 12 * linear_a_min_nbit) / 8,
                       (16 * 16 * 12 * linear_a_min_nbit + 16 * 16 * 12 * default_a_nbit) / 8,
                       (16 * 16 * 12 * default_a_nbit + 10 * linear_a_min_nbit) / 8,
                       10 * linear_a_min_nbit / 8]
        assert self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru)

        # Expected per-node weights memory in bytes: kernels at the minimal mp width,
        # the add-const at the default weight width.
        exp_w_ru = [5 * 5 * 3 * 8 * linear_w_min_nbit / 8,
                    14 * 14 * 8 * default_w_nbits / 8,  # const
                    3 * 3 * 8 * 2 * linear_w_min_nbit / 8,
                    5 * 5 * 16 * 12 * linear_w_min_nbit / 8,
                    (16*16*12) * 10 * linear_w_min_nbit / 8]
        assert self._extract_values(detailed_orig[RUTarget.WEIGHTS]) == exp_w_ru

        # Expected BOPS per linear node: MACs * input-activation nbits * kernel nbits.
        exp_bops = [(5 * 5 * 3 * 8) * (14 * 14) * default_a_nbit * linear_w_min_nbit,
                    (3 * 3 * 8 * 2) * (12 * 12) * binary_out_a_bit * linear_w_min_nbit,
                    (5 * 5 * 16 * 12) * (16 * 16) * linear_a_min_nbit * linear_w_min_nbit,
                    (16 * 16 * 12) * 10 * default_a_nbit * linear_w_min_nbit]
        assert self._extract_values(detailed_orig[RUTarget.BOPS]) == exp_bops

        assert ru_orig == ResourceUtilization(activation_memory=max(exp_cuts_ru),
                                              weights_memory=sum(exp_w_ru),
                                              total_memory=max(exp_cuts_ru) + sum(exp_w_ru),
                                              bops=sum(exp_bops))

        # generate virtual graph and make sure resource utilization results are identical
        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        assert len(virtual_graph.nodes) == 7
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 4
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitActivationNode)]) == 3

        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                            BitwidthMode.QMinBit,
                                                                            return_detailed=True)
        assert ru_virtual == ru_orig

        assert (self._extract_values(detailed_virtual[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru))
        # virtual composed node contains both activation's const and weights' kernel
        assert (self._extract_values(detailed_virtual[RUTarget.WEIGHTS]) ==
                [exp_w_ru[0], sum(exp_w_ru[1:3]), *exp_w_ru[3:]])
        assert self._extract_values(detailed_virtual[RUTarget.BOPS]) == exp_bops

    def test_sequential_graph_with_default_quant_cfg(self):
        """ Using the default quant config, make sure the original and the virtual graphs yield the same results. """
        model = self._build_sequential_model()
        graph, *_ = self._prepare_graph(model)

        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.Any,
                                                                     BitwidthMode.QMaxBit,
                                                                     return_detailed=True)

        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.Any,
                                                                            BitwidthMode.QMaxBit,
                                                                            return_detailed=True)
        # No hand-computed expectations here — only cross-consistency between the two graphs.
        assert ru_orig == ru_virtual
        assert (self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) ==
                self._extract_values(detailed_virtual[RUTarget.ACTIVATION], sort=True))
        assert (self._extract_values(detailed_orig[RUTarget.WEIGHTS]) ==
                self._extract_values(detailed_virtual[RUTarget.WEIGHTS]))
        assert (self._extract_values(detailed_orig[RUTarget.TOTAL], sort=True) ==
                self._extract_values(detailed_virtual[RUTarget.TOTAL], sort=True))
        assert (self._extract_values(detailed_orig[RUTarget.BOPS]) ==
                self._extract_values(detailed_virtual[RUTarget.BOPS]))

    def test_mult_output_activation(self):
        """ Tests the case when input activation has multiple outputs -> virtual weights nodes are not merged
        into VirtualActivationWeightsNode. """
        model = self._build_mult_output_activation_model()

        graph, nbits = self._prepare_graph(model)
        linear_w_min_nbit, linear_a_min_nbit, default_w_nbits, default_a_nbit, binary_out_a_bit = nbits

        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                     BitwidthMode.QMinBit,
                                                                     return_detailed=True)

        # Expected per-cut activation memory in bytes for the branched model
        # (input feeds both conv and dwconv, which merge via subtract).
        exp_cuts_ru = [18*18*3 * default_a_nbit/8,
                       (18*18*3 * default_a_nbit + 16*16*15 * linear_a_min_nbit) / 8,
                       (18*18*3 * default_a_nbit + 2 * (16*16*15 * linear_a_min_nbit)) / 8,
                       16*16*15 * (2*linear_a_min_nbit + binary_out_a_bit) / 8,
                       (16*16*15 * (binary_out_a_bit + default_a_nbit)) / 8,
                       (16*16*15 * default_a_nbit + 10 * linear_a_min_nbit) / 8,
                       10 * linear_a_min_nbit / 8]

        # the order of conv and dwconv is not guaranteed, but they have same values
        exp_w_ru = [3*3*1*15 * linear_w_min_nbit/8,
                    3*3*3*5 * linear_w_min_nbit/8,
                    16*16*15*10 * linear_w_min_nbit/8]
        # bops are not computed for virtual weights nodes
        exp_bops = [(16*16*15*10)*default_a_nbit*linear_w_min_nbit]

        assert self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru)
        assert self._extract_values(detailed_orig[RUTarget.WEIGHTS]) == exp_w_ru
        assert self._extract_values(detailed_orig[RUTarget.BOPS]) == exp_bops

        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        # Only fc merges into a composed A+W node; conv/dwconv stay split because their shared
        # input activation has multiple outputs.
        assert len(virtual_graph.nodes) == 8
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 1
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitActivationNode)]) == 3
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitWeightsNode)]) == 2

        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                            BitwidthMode.QMinBit,
                                                                            return_detailed=True)
        assert ru_virtual == ru_orig
        # conv and dwconv each remain as a pair of virtual W and virtual A nodes. Remaining virtual W nodes mess up the
        # cuts - but this should only add virtualW-virtualA cuts, all cuts from the original graph should stay identical
        assert not set(exp_cuts_ru) - set(detailed_virtual[RUTarget.ACTIVATION].values())
        assert self._extract_values(detailed_virtual[RUTarget.WEIGHTS]) == exp_w_ru
        assert self._extract_values(detailed_virtual[RUTarget.BOPS]) == exp_bops

    def _prepare_graph(self, model, disable_linear_collapse: bool = False):
        """ Run graph preparation on 'model' with the test tpc.

        Returns a tuple of (prepared graph, list of the distinguishing nbits from build_tpc).
        """
        # If disable_linear_collapse is False we use the default quantization config
        tpc, *nbits = build_tpc()
        qcfg = QuantizationConfig(linear_collapsing=False) if disable_linear_collapse else QuantizationConfig()
        graph = graph_preparation_runner(model,
                                         self._data_gen,
                                         quantization_config=qcfg,
                                         fw_info=self.fw_info,
                                         fw_impl=self.fw_impl,
                                         fqc=self.attach_to_fw_func(tpc),
                                         mixed_precision_enable=True,
                                         running_gptq=False)
        return graph, nbits

    @staticmethod
    def _extract_values(res: dict, sort: bool = False):
        """ Extract values for target detailed resource utilization result. """
        values = list(res.values())
        return sorted(values) if sort else values

0 commit comments

Comments
 (0)