Skip to content

Commit 2011098

Browse files
author
reuvenp
committed
merge from main
2 parents d097e77 + 1056e36 commit 2011098

File tree

7 files changed

+127
-23
lines changed

7 files changed

+127
-23
lines changed

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,22 @@ def get_final_activation_config(self) -> List[Tuple[BaseNode, int]]:
740740
sorted_conf_activation = self.get_sorted_activation_configurable_nodes()
741741
return [(n, n.final_activation_quantization_cfg.activation_n_bits) for n in sorted_conf_activation]
742742

743+
def retrieve_preserved_quantization_node(self, node: BaseNode) -> BaseNode:
744+
"""
745+
For a node with quantization_preserving == True, get the previous non-quantization_preserving node
746+
to get activation quantization config from. If quantization_preserving is False, return node.
747+
Args:
748+
node: quantization preserving node.
749+
750+
Returns:
751+
The node that the quantization preserving node should get the activation quantization from.
752+
753+
"""
754+
while node.is_quantization_preserving():
755+
prev_nodes = self.get_prev_nodes(node)
756+
assert len(prev_nodes) == 1, "Activation preserving node should have only 1 input."
757+
node = prev_nodes[0]
758+
return node
743759

744760
def has_any_configurable_activation(self) -> bool:
745761
"""

model_compression_toolkit/core/common/graph/base_node.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,19 @@ def is_activation_quantization_enabled(self) -> bool:
131131
qc.activation_quantization_cfg.enable_activation_quantization
132132
return self.candidates_quantization_cfg[0].activation_quantization_cfg.enable_activation_quantization
133133

134+
def is_quantization_preserving(self) -> bool:
135+
"""
136+
Returns: Whether node activation quantization information is preserved from its inputs.
137+
"""
138+
if self.final_activation_quantization_cfg:
139+
# if we have a final configuration, then we only care to check if it enables activation quantization.
140+
return self.final_activation_quantization_cfg.quantization_preserving
141+
142+
for qc in self.candidates_quantization_cfg:
143+
assert self.candidates_quantization_cfg[0].activation_quantization_cfg.quantization_preserving == \
144+
qc.activation_quantization_cfg.quantization_preserving
145+
return self.candidates_quantization_cfg[0].activation_quantization_cfg.quantization_preserving
146+
134147
def is_weights_quantization_enabled(self, attr_name: str) -> bool:
135148
"""
136149
Checks whether a node's weights attribute quantization is enabled.

model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -335,13 +335,35 @@ def compute_activations_utilization(self,
335335
"""
336336
return self.compute_activation_utilization_by_cut(target_criterion, bitwidth_mode, act_qcs)
337337

338+
def _extract_qc(self, n: BaseNode, act_qcs: Optional[ActivationQCfgPerNode] = None
339+
) -> Union[NodeActivationQuantizationConfig, None]:
340+
"""
341+
Extract quantization config if the activation configs dictionary is provided. If node is quantization
342+
preserving, extract the quantization config from the preceding activation quantized node (i.e.
343+
the quantization that the original node preserves).
344+
345+
Args:
346+
n: Node to extract qc for.
347+
act_qcs: custom activations quantization configuration. If not provided, the default
348+
configuration will be extracted from the node.
349+
350+
Returns:
351+
The relevant quantization config.
352+
"""
353+
if act_qcs:
354+
assert not (n.is_quantization_preserving() and act_qcs.get(n.name) is not None), \
355+
f"Quantization preserving node {n.name} should not have a qc for this computation."
356+
return act_qcs.get(self.graph.retrieve_preserved_quantization_node(n).name)
357+
return None
358+
338359
def compute_activation_utilization_by_cut(self,
339360
target_criterion: TargetInclusionCriterion,
340361
bitwidth_mode: BitwidthMode,
341362
act_qcs: Optional[ActivationQCfgPerNode] = None) \
342363
-> Tuple[float, Dict[Cut, Utilization], Dict[Cut, Dict[BaseNode, Utilization]]]:
343364
"""
344-
Compute graph activation cuts utilization.
365+
Compute graph activation cuts utilization. If activation quantization configs are provided, then for
366+
quantization preserving nodes, get the previous quantized activation node bit-width.
345367
346368
Args:
347369
target_criterion: criterion to include weights for computation.
@@ -369,7 +391,7 @@ def compute_activation_utilization_by_cut(self,
369391
if not cut_target_nodes:
370392
continue
371393
for n in cut_target_nodes:
372-
qc = act_qcs.get(n.name) if act_qcs else None
394+
qc = self._extract_qc(n, act_qcs)
373395
util_per_cut_per_node[cut][n.name] = self.compute_node_activation_tensor_utilization(n, target_criterion,
374396
bitwidth_mode, qc)
375397
util_per_cut[cut] = sum(util_per_cut_per_node[cut].values()) # type: ignore
@@ -384,7 +406,8 @@ def compute_activation_tensors_utilization(self,
384406
include_reused=False) \
385407
-> Tuple[float, Dict[NodeName, Utilization]]:
386408
"""
387-
Compute resource utilization for graph's activations tensors.
409+
Compute resource utilization for graph's activations tensors. If activation quantization configs are provided, then for
410+
quantization preserving nodes, get the previous quantized activation node bit-width.
388411
389412
Args:
390413
target_criterion: criterion to include weights for computation.
@@ -405,7 +428,7 @@ def compute_activation_tensors_utilization(self,
405428

406429
util_per_node: Dict[NodeName, Utilization] = {}
407430
for n in self._topo_sort(nodes):
408-
qc = act_qcs.get(n.name) if act_qcs else None
431+
qc = self._extract_qc(n, act_qcs)
409432
util = self.compute_node_activation_tensor_utilization(n, None, bitwidth_mode, qc)
410433
util_per_node[n.name] = util
411434

@@ -659,7 +682,7 @@ def _get_target_activation_nodes(self,
659682
if target_criterion == TargetInclusionCriterion.QConfigurable:
660683
nodes = [n for n in nodes if n.has_configurable_activation()]
661684
elif target_criterion == TargetInclusionCriterion.AnyQuantized:
662-
nodes = [n for n in nodes if n.is_activation_quantization_enabled()]
685+
nodes = [n for n in nodes if n.is_activation_quantization_enabled() or n.is_quantization_preserving()]
663686
elif target_criterion == TargetInclusionCriterion.QNonConfigurable:
664687
nodes = [n for n in nodes if n.is_activation_quantization_enabled() and not n.has_configurable_activation()]
665688
elif target_criterion != TargetInclusionCriterion.Any: # pragma: no cover
@@ -668,8 +691,7 @@ def _get_target_activation_nodes(self,
668691
nodes = [n for n in nodes if not n.reuse]
669692
return nodes
670693

671-
@classmethod
672-
def _get_activation_nbits(cls,
694+
def _get_activation_nbits(self,
673695
n: BaseNode,
674696
bitwidth_mode: BitwidthMode,
675697
act_qc: Optional[NodeActivationQuantizationConfig]) -> int:
@@ -690,21 +712,22 @@ def _get_activation_nbits(cls,
690712
assert bitwidth_mode == BitwidthMode.QCustom
691713
return act_qc.activation_n_bits if act_qc.enable_activation_quantization else FLOAT_BITWIDTH
692714

693-
if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled():
715+
if bitwidth_mode == BitwidthMode.Float or not (n.is_activation_quantization_enabled() or
716+
n.is_quantization_preserving()):
694717
return FLOAT_BITWIDTH
695718

696719
if bitwidth_mode == BitwidthMode.Q8Bit:
697720
return 8
698721

699-
if bitwidth_mode in cls._bitwidth_mode_fn:
722+
if bitwidth_mode in self._bitwidth_mode_fn:
700723
candidates_nbits = [c.activation_quantization_cfg.activation_n_bits for c in n.candidates_quantization_cfg]
701-
return cls._bitwidth_mode_fn[bitwidth_mode](candidates_nbits)
724+
return self._bitwidth_mode_fn[bitwidth_mode](candidates_nbits)
702725

703726
if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]:
704-
qcs = n.get_unique_activation_candidates()
727+
qcs = self.graph.retrieve_preserved_quantization_node(n).get_unique_activation_candidates()
705728
if len(qcs) != 1:
706729
raise ValueError(f'Could not retrieve the activation quantization candidate for node {n} '
707-
f'as it has {len(qcs)}!=1 unique candidates .')
730+
f'as it has {len(qcs)}!=1 unique candidates.')
708731
return qcs[0].activation_quantization_cfg.activation_n_bits
709732

710733
raise ValueError(f'Unknown mode {bitwidth_mode}') # pragma: no cover

tests_pytest/_test_util/graph_builder_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def full_attr_name(canonical_name: Union[str, dict, Iterable]):
7070

7171

7272
def build_nbits_qc(a_nbits=8, a_enable=True, w_attr=None, pos_attr=(32, False, ()),
73-
convert_canonical_attr=True) -> CandidateNodeQuantizationConfig:
73+
convert_canonical_attr=True, q_preserving=False) -> CandidateNodeQuantizationConfig:
7474
"""
7575
Build quantization config with configurable nbits and enabling/disabling quantization only.
7676
@@ -87,6 +87,8 @@ def build_nbits_qc(a_nbits=8, a_enable=True, w_attr=None, pos_attr=(32, False, (
8787
Returns:
8888
8989
"""
90+
assert not(a_enable and q_preserving)
91+
9092
w_attr = w_attr or {}
9193
attr_weights_configs_mapping = {
9294
k: AttributeQuantizationConfig(weights_n_bits=v[0], enable_weights_quantization=v[1])
@@ -102,7 +104,7 @@ def build_nbits_qc(a_nbits=8, a_enable=True, w_attr=None, pos_attr=(32, False, (
102104
default_weight_attr_config=AttributeQuantizationConfig(weights_n_bits=pos_attr[0],
103105
enable_weights_quantization=pos_attr[1]),
104106
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
105-
quantization_preserving=False,
107+
quantization_preserving=q_preserving,
106108
supported_input_activation_n_bits=[2, 4, 8],
107109
fixed_scale=None,
108110
fixed_zero_point=None,
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from model_compression_toolkit.core.common import Graph
16+
from model_compression_toolkit.core.common.graph.edge import Edge
17+
18+
from tests_pytest._test_util.graph_builder_utils import build_node, build_nbits_qc
19+
20+
21+
class TestQuantizationPreservingNode:
22+
23+
def test_activation_preserving_candidate(self):
24+
""" Tests that the correct activation quantization candidate is selected. """
25+
n1 = build_node('qact_node', qcs=[build_nbits_qc()])
26+
n2 = build_node('qp1a_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
27+
n3 = build_node('qp1b_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
28+
n4 = build_node('qp2a_node', qcs=[build_nbits_qc()])
29+
n5 = build_node('qp2b_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
30+
graph = Graph('g', input_nodes=[n1], nodes=[n2, n4], output_nodes=[n3, n5],
31+
edge_list=[Edge(n1, n2, 0, 0), Edge(n2, n3, 0, 0),
32+
Edge(n1, n4, 0, 0), Edge(n4, n5, 0, 0)])
33+
34+
assert graph.retrieve_preserved_quantization_node(n2) is n1
35+
assert graph.retrieve_preserved_quantization_node(n3) is n1
36+
assert graph.retrieve_preserved_quantization_node(n4) is n4
37+
assert graph.retrieve_preserved_quantization_node(n5) is n4

tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
BM = BitwidthMode
3838
TIC = TargetInclusionCriterion
3939

40+
_identity_func = lambda x: x
41+
4042

4143
class TestUtilization:
4244
def test_operations(self):
@@ -296,8 +298,10 @@ class TestComputeActivationTensorsUtilization:
296298
""" Tests for activation tensors utilization public apis. """
297299
def test_compute_node_activation_tensor_utilization(self, graph_mock, fw_impl_mock, fw_info_mock):
298300
mp_reuse = build_node('mp_reuse', output_shape=(None, 3, 14), qcs=[build_qc(4), build_qc(16)], reuse=True)
301+
qp = build_node('qp', output_shape=(None, 15, 9), qcs=[build_qc(a_enable=False, q_preserving=True)])
299302
noq = build_node('noq', output_shape=(None, 15, 9), qcs=[build_qc(a_enable=False)])
300-
graph_mock.nodes = [mp_reuse, noq]
303+
graph_mock.nodes = [mp_reuse, qp, noq]
304+
graph_mock.retrieve_preserved_quantization_node = lambda n: mp_reuse if n is qp else n
301305

302306
ru_calc = ResourceUtilizationCalculator(graph_mock, fw_impl_mock, fw_info_mock)
303307
# _get_activation_nbits is already fully checked, just make sure we use it, and use correctly
@@ -310,6 +314,9 @@ def test_compute_node_activation_tensor_utilization(self, graph_mock, fw_impl_mo
310314
# reused is not ignored
311315
res = ru_calc.compute_node_activation_tensor_utilization(mp_reuse, TIC.QConfigurable, BM.QMinBit)
312316
assert res == Utilization(42, 21.)
317+
# quantization preserving uses custom_qc.
318+
res = ru_calc.compute_node_activation_tensor_utilization(qp, TIC.AnyQuantized, BM.QCustom, custom_qc)
319+
assert res == Utilization(135, 270.)
313320
# not a target node
314321
res = ru_calc.compute_node_activation_tensor_utilization(noq, TIC.AnyQuantized, BM.QCustom, custom_qc)
315322
assert res == Utilization(0, 0)
@@ -391,11 +398,14 @@ def test_compute_cuts_integration(self, graph_mock, fw_impl_mock, fw_info_mock,
391398
""" Test integration with max cut computation. """
392399
# Test a simple linear dummy graph with the real max cut computation.
393400
n1 = build_node('n1', qcs=[build_qc()], input_shape=(None, 10, 20, 3), output_shape=(None, 10, 20, 3))
401+
n1_qp = build_node('n1_qp', qcs=[build_qc(a_enable=False, q_preserving=True)],
402+
input_shape=(None, 10, 20, 3), output_shape=(None, 10, 20, 3))
394403
n2 = build_node('n2', qcs=[build_qc()], input_shape=(None, 10, 20, 3), output_shape=(None, 5, 10))
395404
n3 = build_node('n3', qcs=[build_qc()], input_shape=(None, 5, 10), output_shape=(None, 5, 10))
396405
n4 = build_node('n4', qcs=[build_qc()], input_shape=(None, 5, 10, 32), output_shape=(None, 5, 10, 32))
397-
edges = [Edge(n1, n2, 0, 0), Edge(n2, n3, 0, 0), Edge(n3, n4, 0, 0)]
398-
graph = Graph('g', input_nodes=[n1], nodes=[n2, n3], output_nodes=[n4], edge_list=edges)
406+
edges = [Edge(n1, n1_qp, 0, 0), Edge(n1_qp, n2, 0, 0),
407+
Edge(n2, n3, 0, 0), Edge(n3, n4, 0, 0)]
408+
graph = Graph('g', input_nodes=[n1], nodes=[n1_qp, n2, n3], output_nodes=[n4], edge_list=edges)
399409
ru_calc = ResourceUtilizationCalculator(graph, fw_impl_mock, fw_info_mock)
400410
# wrap the real implementation
401411
maxcut_spy = mocker.patch('model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.'
@@ -405,11 +415,11 @@ def test_compute_cuts_integration(self, graph_mock, fw_impl_mock, fw_info_mock,
405415
cuts_cache = ru_calc.cuts
406416

407417
# verify the cache
408-
assert len(cuts_cache) == 5
418+
assert len(cuts_cache) == 6
409419
assert all(isinstance(k, Cut) for k in cuts_cache.keys())
410420
# for each cut we save a list of its nodes
411421
cuts_nodes = {tuple(sorted(n.name for n in nodes)) for nodes in cuts_cache.values()}
412-
assert cuts_nodes == {('n1',), ('n4',), ('n1', 'n2'), ('n2', 'n3'), ('n3', 'n4')}
422+
assert cuts_nodes == {('n1',), ('n4',), ('n1', 'n1_qp'), ('n1_qp', 'n2'), ('n2', 'n3'), ('n3', 'n4')}
413423

414424
# verify cuts computation only happens the first time
415425
cuts_cache2 = ru_calc.cuts
@@ -420,7 +430,8 @@ def test_compute_cuts_integration(self, graph_mock, fw_impl_mock, fw_info_mock,
420430
nodes_to_cuts = {tuple(sorted(elem.node_name for elem in cut.mem_elements.elements)): cut
421431
for cut in cuts_cache.keys()}
422432
cut1 = nodes_to_cuts[('n1',)]
423-
cut12 = nodes_to_cuts[('n1', 'n2')]
433+
cut11 = nodes_to_cuts[('n1', 'n1_qp')]
434+
cut12 = nodes_to_cuts[('n1_qp', 'n2')]
424435
cut23 = nodes_to_cuts[('n2', 'n3')]
425436
cut34 = nodes_to_cuts[('n3', 'n4')]
426437
cut4 = nodes_to_cuts[('n4',)]
@@ -430,7 +441,8 @@ def test_compute_cuts_integration(self, graph_mock, fw_impl_mock, fw_info_mock,
430441
bitwidth_mode=BM.QDefaultSP)
431442

432443
assert per_cut_per_node == {cut1: {'n1': Utilization(10 * 20 * 3, 600)},
433-
cut12: {'n1': Utilization(10 * 20 * 3, 600),
444+
cut11: {'n1': Utilization(10 * 20 * 3, 600), 'n1_qp': Utilization(10 * 20 * 3, 600)},
445+
cut12: {'n1_qp': Utilization(10 * 20 * 3, 600),
434446
'n2': Utilization(5 * 10, 50)},
435447
cut23: {'n2': Utilization(5*10, 50),
436448
'n3': Utilization(5*10, 50)},
@@ -439,7 +451,8 @@ def test_compute_cuts_integration(self, graph_mock, fw_impl_mock, fw_info_mock,
439451
cut4: {'n4': Utilization(5 * 10 * 32, 1600)}}
440452
assert per_cut == {
441453
nodes_to_cuts[('n1',)]: Utilization(600, 600),
442-
nodes_to_cuts[('n1', 'n2')]: Utilization(650, 650),
454+
nodes_to_cuts[('n1', 'n1_qp')]: Utilization(1200, 1200),
455+
nodes_to_cuts[('n1_qp', 'n2')]: Utilization(650, 650),
443456
nodes_to_cuts[('n2', 'n3')]: Utilization(100, 100),
444457
nodes_to_cuts[('n3', 'n4')]: Utilization(1650, 1650),
445458
nodes_to_cuts[('n4',)]: Utilization(1600, 1600)

tests_pytest/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def minimal_tpc():
3131
@fixture
3232
def graph_mock():
3333
""" Basic Graph mock. """
34-
return Mock(spec_set=Graph, nodes=[])
34+
return Mock(spec_set=Graph, nodes=[], retrieve_preserved_quantization_node=lambda x: x)
3535

3636

3737
@fixture

0 commit comments

Comments
 (0)