Commit 32480a5

irenab authored and committed
add ru calculator integration test for torch
1 parent d10483b commit 32480a5

File tree

6 files changed: +431 −215 lines
Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import abc
import copy
from typing import Callable, Generator

from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
    VirtualSplitActivationNode, VirtualSplitWeightsNode
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
    RUTarget, ResourceUtilization
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
    ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
from model_compression_toolkit.target_platform_capabilities import QuantizationMethod, AttributeQuantizationConfig, \
    OpQuantizationConfig, QuantizationConfigOptions, Signedness, OperatorsSet, OperatorSetNames, \
    TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR


def build_tpc():
    """ Build a minimal TPC containing linear and binary ops, with configurable activation+weights quantization
        for linear ops, and distinguishable n-bits for default / linear / binary activations and for
        const / linear weights. """
    default_w_cfg = AttributeQuantizationConfig(weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
                                                weights_n_bits=8,
                                                weights_per_channel_threshold=True,
                                                enable_weights_quantization=True)
    default_w_nbit = 16
    default_a_nbit = 8
    default_op_cfg = OpQuantizationConfig(
        default_weight_attr_config=default_w_cfg.clone_and_edit(weights_n_bits=default_w_nbit),
        attr_weights_configs_mapping={},
        activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
        activation_n_bits=default_a_nbit,
        supported_input_activation_n_bits=[16, 8, 4, 2],
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=32,
        signedness=Signedness.AUTO)

    default_w_op_cfg = default_op_cfg.clone_and_edit(
        attr_weights_configs_mapping={KERNEL_ATTR: default_w_cfg, BIAS_ATTR: AttributeQuantizationConfig()}
    )

    mp_configs = []
    linear_w_min_nbit = 2
    linear_a_min_nbit = 4
    for w_nbit in [linear_w_min_nbit, linear_w_min_nbit * 2, linear_w_min_nbit * 4]:
        for a_nbit in [linear_a_min_nbit, linear_a_min_nbit * 2]:
            attr_cfg = default_w_cfg.clone_and_edit(weights_n_bits=w_nbit)
            mp_configs.append(default_w_op_cfg.clone_and_edit(
                attr_weights_configs_mapping={KERNEL_ATTR: attr_cfg, BIAS_ATTR: AttributeQuantizationConfig()},
                activation_n_bits=a_nbit
            ))
    mp_cfg_options = QuantizationConfigOptions(quantization_configurations=mp_configs,
                                               base_config=default_w_op_cfg)

    linear_ops = [OperatorsSet(name=opset, qc_options=mp_cfg_options) for opset in (OperatorSetNames.CONV,
                                                                                    OperatorSetNames.CONV_TRANSPOSE,
                                                                                    OperatorSetNames.DEPTHWISE_CONV,
                                                                                    OperatorSetNames.FULLY_CONNECTED)]

    default_cfg = QuantizationConfigOptions(quantization_configurations=[default_op_cfg])

    binary_out_a_bit = 16
    binary_cfg = QuantizationConfigOptions(
        quantization_configurations=[default_op_cfg.clone_and_edit(activation_n_bits=binary_out_a_bit)]
    )
    binary_ops = [
        OperatorsSet(name=opset, qc_options=binary_cfg) for opset in (OperatorSetNames.ADD, OperatorSetNames.SUB)
    ]

    tpc = TargetPlatformCapabilities(default_qco=default_cfg,
                                     tpc_platform_type='test',
                                     operator_set=linear_ops + binary_ops,
                                     fusing_patterns=None)

    assert linear_w_min_nbit != default_w_nbit
    assert len({linear_a_min_nbit, default_a_nbit, binary_out_a_bit}) == 3
    return tpc, linear_w_min_nbit, linear_a_min_nbit, default_w_nbit, default_a_nbit, binary_out_a_bit


class BaseRUIntegrationTester(abc.ABC):
    """ Test the resource utilization calculator on a real framework model with graph preparation. """
    fw_info: FrameworkInfo
    fw_impl: FrameworkImplementation
    attach_to_fw_func: Callable

    bhwc_input_shape = (1, 18, 18, 3)

    @abc.abstractmethod
    def _build_sequential_model(self):
        r""" Build a framework model for test_orig_vs_virtual_graph_ru:
             conv2d(k=5, filters=8) -> add const(14, 14, 8) -> dwconv(k=3, dm=2) ->
             -> conv_transpose(k=5, filters=12) -> flatten -> fc(10)
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _build_mult_output_activation_model(self):
        r""" Build a framework model for test_mult_output_activation:
             x - conv2d(k=3, filters=15, groups=3) \ subtract -> flatten -> fc(10)
               \ dwconv2d(k=3, dm=5)               /
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _data_gen(self) -> Generator:
        """ Build a framework data generator yielding inputs of 'bhwc_input_shape'. """
        raise NotImplementedError()

    def test_orig_vs_virtual_graph_ru(self):
        """ Test detailed RU computation on the original graph and the corresponding virtual graph. """
        # The model is sequential so that activation cuts are uniquely defined.
        model = self._build_sequential_model()
        # test the original graph
        graph, nbits = self._prepare_graph(model, disable_linear_collapse=True)
        linear_w_min_nbit, linear_a_min_nbit, default_w_nbits, default_a_nbit, binary_out_a_bit = nbits

        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                      BitwidthMode.QMinBit,
                                                                      return_detailed=True)

        exp_cuts_ru = [18 * 18 * 3 * default_a_nbit / 8,
                       (18 * 18 * 3 * default_a_nbit + 14 * 14 * 8 * linear_a_min_nbit) / 8,
                       (14 * 14 * 8 * linear_a_min_nbit + 14 * 14 * 8 * binary_out_a_bit) / 8,
                       (14 * 14 * 8 * binary_out_a_bit + 12 * 12 * 16 * linear_a_min_nbit) / 8,
                       (12 * 12 * 16 * linear_a_min_nbit + 16 * 16 * 12 * linear_a_min_nbit) / 8,
                       (16 * 16 * 12 * linear_a_min_nbit + 16 * 16 * 12 * default_a_nbit) / 8,
                       (16 * 16 * 12 * default_a_nbit + 10 * linear_a_min_nbit) / 8,
                       10 * linear_a_min_nbit / 8]
        assert self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru)

        exp_w_ru = [5 * 5 * 3 * 8 * linear_w_min_nbit / 8,
                    14 * 14 * 8 * default_w_nbits / 8,  # const
                    3 * 3 * 8 * 2 * linear_w_min_nbit / 8,
                    5 * 5 * 16 * 12 * linear_w_min_nbit / 8,
                    (16 * 16 * 12) * 10 * linear_w_min_nbit / 8]
        assert self._extract_values(detailed_orig[RUTarget.WEIGHTS]) == exp_w_ru

        exp_bops = [(5 * 5 * 3 * 8) * (14 * 14) * default_a_nbit * linear_w_min_nbit,
                    (3 * 3 * 8 * 2) * (12 * 12) * binary_out_a_bit * linear_w_min_nbit,
                    (5 * 5 * 16 * 12) * (16 * 16) * linear_a_min_nbit * linear_w_min_nbit,
                    (16 * 16 * 12) * 10 * default_a_nbit * linear_w_min_nbit]
        assert self._extract_values(detailed_orig[RUTarget.BOPS]) == exp_bops

        assert ru_orig == ResourceUtilization(activation_memory=max(exp_cuts_ru),
                                              weights_memory=sum(exp_w_ru),
                                              total_memory=max(exp_cuts_ru) + sum(exp_w_ru),
                                              bops=sum(exp_bops))

        # generate the virtual graph and make sure resource utilization results are identical
        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        assert len(virtual_graph.nodes) == 7
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 4
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitActivationNode)]) == 3

        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                            BitwidthMode.QMinBit,
                                                                            return_detailed=True)
        assert ru_virtual == ru_orig

        assert self._extract_values(detailed_virtual[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru)
        # the composed virtual node contains both the activation's const and the weights' kernel
        assert (self._extract_values(detailed_virtual[RUTarget.WEIGHTS]) ==
                [exp_w_ru[0], sum(exp_w_ru[1:3]), *exp_w_ru[3:]])
        assert self._extract_values(detailed_virtual[RUTarget.BOPS]) == exp_bops

    def test_mult_output_activation(self):
        """ Tests the case where an input activation has multiple outputs, so that virtual weights nodes are not
            merged into a VirtualActivationWeightsNode. """
        model = self._build_mult_output_activation_model()

        graph, nbits = self._prepare_graph(model)
        linear_w_min_nbit, linear_a_min_nbit, default_w_nbits, default_a_nbit, binary_out_a_bit = nbits

        ru_calc = ResourceUtilizationCalculator(graph, self.fw_impl, self.fw_info)
        ru_orig, detailed_orig = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                      BitwidthMode.QMinBit,
                                                                      return_detailed=True)

        exp_cuts_ru = [18*18*3 * default_a_nbit/8,
                       (18*18*3 * default_a_nbit + 16*16*15 * linear_a_min_nbit) / 8,
                       (18*18*3 * default_a_nbit + 2 * (16*16*15 * linear_a_min_nbit)) / 8,
                       16*16*15 * (2*linear_a_min_nbit + binary_out_a_bit) / 8,
                       (16*16*15 * (binary_out_a_bit + default_a_nbit)) / 8,
                       (16*16*15 * default_a_nbit + 10 * linear_a_min_nbit) / 8,
                       10 * linear_a_min_nbit / 8]

        # the order of conv and dwconv is not guaranteed, but they have the same values
        exp_w_ru = [3*3*1*15 * linear_w_min_nbit/8,
                    3*3*3*5 * linear_w_min_nbit/8,
                    16*16*15*10 * linear_w_min_nbit/8]
        # bops are not computed for virtual weights nodes
        exp_bops = [(16*16*15*10)*default_a_nbit*linear_w_min_nbit]

        assert self._extract_values(detailed_orig[RUTarget.ACTIVATION], sort=True) == sorted(exp_cuts_ru)
        assert self._extract_values(detailed_orig[RUTarget.WEIGHTS]) == exp_w_ru
        assert self._extract_values(detailed_orig[RUTarget.BOPS]) == exp_bops

        virtual_graph = substitute(copy.deepcopy(graph),
                                   self.fw_impl.get_substitutions_virtual_weights_activation_coupling())
        assert len(virtual_graph.nodes) == 8
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualActivationWeightsNode)]) == 1
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitActivationNode)]) == 3
        assert len([n for n in virtual_graph.nodes if isinstance(n, VirtualSplitWeightsNode)]) == 2

        ru_calc = ResourceUtilizationCalculator(virtual_graph, self.fw_impl, self.fw_info)
        ru_virtual, detailed_virtual = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                            BitwidthMode.QMinBit,
                                                                            return_detailed=True)
        assert ru_virtual == ru_orig
        # conv and dwconv each remain as a pair of virtual W and virtual A nodes. The remaining virtual W nodes
        # mess up the cuts, but this should only add virtualW-virtualA cuts; all cuts from the original graph
        # should stay identical.
        assert not set(exp_cuts_ru) - set(detailed_virtual[RUTarget.ACTIVATION].values())
        assert self._extract_values(detailed_virtual[RUTarget.WEIGHTS]) == exp_w_ru
        assert self._extract_values(detailed_virtual[RUTarget.BOPS]) == exp_bops

    def _prepare_graph(self, model, disable_linear_collapse: bool = False):
        # If disable_linear_collapse is False, the default quantization config is used.
        tpc, *nbits = build_tpc()
        qcfg = QuantizationConfig(linear_collapsing=False) if disable_linear_collapse else QuantizationConfig()
        graph = graph_preparation_runner(model,
                                         self._data_gen,
                                         quantization_config=qcfg,
                                         fw_info=self.fw_info,
                                         fw_impl=self.fw_impl,
                                         fqc=self.attach_to_fw_func(tpc),
                                         mixed_precision_enable=True,
                                         running_gptq=False)
        return graph, nbits

    @staticmethod
    def _extract_values(res: dict, sort=False):
        """ Extract the values of a target's detailed resource utilization result. """
        values = list(res.values())
        return sorted(values) if sort else values
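
For reference, each entry in exp_cuts_ru is the number of activation elements alive at a cut, multiplied by the activation bit-width and divided by 8 to convert bits to bytes. A standalone sanity check of the second cut of the sequential model (values copied from the test above; this snippet is illustrative and not part of the commit):

# second cut: network input (18x18x3 at the default 8-bit activation) plus
# the conv output (14x14x8 at the minimal 4-bit linear activation), in bytes
default_a_nbit, linear_a_min_nbit = 8, 4
cut_bytes = (18 * 18 * 3 * default_a_nbit + 14 * 14 * 8 * linear_a_min_nbit) / 8
assert cut_bytes == 1756.0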
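
The commit message says the torch tester lives in one of the other changed files; it is not shown in this diff. A minimal sketch of what a torch subclass might look like, following the abstract hooks' docstrings (hypothetical code, not the commit's actual torch tester; the class names and layer choices are assumptions):

import torch
from torch import nn

class TorchRUIntegrationTester(BaseRUIntegrationTester):
    # fw_info / fw_impl / attach_to_fw_func would come from MCT's pytorch
    # implementation; they and _build_mult_output_activation_model are
    # omitted here for brevity.

    def _build_sequential_model(self):
        class SeqModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 8, kernel_size=5)               # (1, 3, 18, 18) -> (1, 8, 14, 14)
                self.const = torch.randn(8, 14, 14)                      # the (14, 14, 8) const in BCHW layout
                self.dwconv = nn.Conv2d(8, 16, kernel_size=3, groups=8)  # depth multiplier 2 -> (1, 16, 12, 12)
                self.convtr = nn.ConvTranspose2d(16, 12, kernel_size=5)  # -> (1, 12, 16, 16)
                self.fc = nn.Linear(16 * 16 * 12, 10)

            def forward(self, x):
                x = self.conv(x) + self.const
                x = self.convtr(self.dwconv(x))
                return self.fc(torch.flatten(x, 1))

        return SeqModel()

    def _data_gen(self):
        # torch is channels-first, so the BHWC shape is permuted to BCHW
        b, h, w, c = self.bhwc_input_shape
        yield [torch.randn(b, c, h, w)]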

0 commit comments