Skip to content

Commit 8ec2fe6

Browse files
irenab
authored and committed
add tests for combining weights from activation and weight nodes in virtual node construction
1 parent 31bae84 commit 8ec2fe6

File tree

6 files changed

+188
-20
lines changed

6 files changed

+188
-20
lines changed

model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,21 @@ def __init__(self,
140140
"""
141141
# Validate weights node
142142
kernel_attrs = fw_info.get_kernel_op_attributes(weights_node.type)
143-
assert len(kernel_attrs) == 1 and kernel_attrs[0] is not None, 'Expected exactly one kernel attr.'
143+
assert len(kernel_attrs) == 1 and kernel_attrs[0] is not None, f'Expected exactly one kernel attr, {kernel_attrs}'
144144
kernel_attr = kernel_attrs[0]
145145
conf_weights = [attr for attr in weights_node.weights if weights_node.is_configurable_weight(attr)]
146146
if len(conf_weights) > 1 or len(conf_weights) == 1 and not weights_node.is_configurable_weight(kernel_attr):
147-
raise NotImplementedError('Only kernel weight can be configurable.') # pragma: no cover
147+
raise NotImplementedError(f'Only kernel weight can be configurable. Got configurable {conf_weights}.')
148148

149-
weights = weights_node.weights
149+
weights = weights_node.weights.copy()
150150
act_node_w_rename = {}
151151
if act_node.weights:
152-
assert fw_info.get_kernel_op_attributes(act_node)[0] is None, \
153-
f'Node {act_node} with kernel cannot be used as activation for VirtualActivationWeightsNode.'
152+
if not fw_info.get_kernel_op_attributes(act_node)[0] is None:
153+
raise NotImplementedError(f'Node {act_node} with kernel cannot be used as activation for '
154+
f'VirtualActivationWeightsNode.')
154155
if act_node.has_any_configurable_weight():
155-
raise NotImplementedError('Node with a configurable weight cannot be used as activation for '
156-
'VirtualActivationWeightsNode.') # pragma: no cover
156+
raise NotImplementedError(f'Node {act_node} with a configurable weight cannot be used as activation for '
157+
'VirtualActivationWeightsNode.')
157158
# combine weights from activation and weights
158159
for w_id, w in act_node.weights.items():
159160
if w_id not in weights and not (isinstance(w_id, str) and kernel_attr in w_id):

tests_pytest/_test_util/graph_builder_utils.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from typing import Union, Iterable
15+
from typing import Union, Iterable, List
1616

1717
from mct_quantizers import QuantizationMethod
1818
from model_compression_toolkit.core import QuantizationConfig
@@ -31,13 +31,20 @@ class DummyLayer:
3131
pass
3232

3333

34-
def build_node(name='node', canonical_weights: dict=None, qcs=None, input_shape=(4, 5, 6), output_shape=(4, 5, 6),
34+
def build_node(name='node', canonical_weights: dict = None, final_weights: dict = None,
35+
qcs: List[CandidateNodeQuantizationConfig] = None,
36+
input_shape=(4, 5, 6), output_shape=(4, 5, 6),
3537
layer_class=DummyLayer, reuse=False):
3638
""" Build a node for tests.
37-
Canonical weights are converted into full unique names.
39+
Either 'canonical_weights' (to be used by default) or 'final_weights' should be passed.
40+
Canonical weights are converted into full unique names, that contain the canonical name as a substring.
41+
Final weights are used as is.
3842
candidate_quantization_cfg is set is qcs is passed."""
39-
weights = canonical_weights or {}
40-
weights = {k if isinstance(k, int) else full_attr_name(k): w for k, w in weights.items()}
43+
assert canonical_weights is None or final_weights is None
44+
if canonical_weights:
45+
weights = {k if isinstance(k, int) else full_attr_name(k): w for k, w in canonical_weights.items()}
46+
else:
47+
weights = final_weights or {}
4148
node = BaseNode(name=name,
4249
framework_attr={},
4350
input_shape=input_shape,
@@ -46,6 +53,7 @@ def build_node(name='node', canonical_weights: dict=None, qcs=None, input_shape=
4653
layer_class=layer_class,
4754
reuse=reuse)
4855
if qcs:
56+
assert isinstance(qcs, list)
4957
node.candidates_quantization_cfg = qcs
5058
return node
5159

@@ -61,10 +69,24 @@ def full_attr_name(canonical_name: Union[str, dict, Iterable]):
6169
return canonical_name.__class__([convert(name) for name in canonical_name])
6270

6371

64-
def build_qc(a_nbits=8, a_enable=True, w_attr=None, pos_attr=(32, False, ())):
65-
""" Build quantization config for tests.
66-
w_attr contains {canonical name: (nbits, q_enabled)}
67-
pos_attr: (nbits, q enabled, indices) """
72+
def build_nbits_qc(a_nbits=8, a_enable=True, w_attr=None, pos_attr=(32, False, ()),
73+
convert_canonical_attr=True) -> CandidateNodeQuantizationConfig:
74+
"""
75+
Build quantization config with configurable nbits and enabling/disabling quantization only.
76+
77+
Args:
78+
a_nbits: activation num bits.
79+
a_enable: whether to enable activation quantization.
80+
w_attr: quantization configuration for weight attributes in format {canonical name: (nbits, q_enabled)}.
81+
By default, a canonical weight name is expected and is automatically converted to a dummy full name (that
82+
contains the canonical name as a substring).
83+
Final name can be passed along with convert_canonical_attr=False.
84+
pos_attr: quantization configuration for positional weights in format (nbits, q enabled, indices).
85+
convert_canonical_attr: whether to convert w_attr keys to full names.
86+
87+
Returns:
88+
89+
"""
6890
w_attr = w_attr or {}
6991
attr_weights_configs_mapping = {
7092
k: AttributeQuantizationConfig(weights_n_bits=v[0], enable_weights_quantization=v[1])
@@ -91,7 +113,9 @@ def build_qc(a_nbits=8, a_enable=True, w_attr=None, pos_attr=(32, False, ())):
91113
activation_quantization_fn=None,
92114
activation_quantization_params_fn=None)
93115
# full names from the layers
94-
attr_names = [full_attr_name(k) for k in w_attr.keys()]
116+
attr_names = list(w_attr.keys())
117+
if convert_canonical_attr:
118+
attr_names = [full_attr_name(k) for k in w_attr.keys()]
95119
w_qcfg = NodeWeightsQuantizationConfig(qc=qc, op_cfg=op_cfg,
96120
weights_channels_axis=None,
97121
node_attrs_list=attr_names + list(pos_attr[2]))
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import numpy as np
16+
import pytest
17+
18+
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode
19+
from tests_pytest._test_util.graph_builder_utils import build_node, DummyLayer, build_nbits_qc
20+
21+
22+
class DummyLayerWKernel:
23+
pass
24+
25+
26+
class TestVirtualActivationWeightsNode:
27+
# TODO tests only cover combining weights from activation and weight nodes and errors.
28+
def test_activation_with_weights(self, fw_info_mock):
29+
""" Tests that weights from activation and weight node are combined correctly. """
30+
# Each node has a unique weight attr and a unique positional weights. In addition, both nodes have
31+
# an identical canonical attribute (but different full name), and an identical positional weight.
32+
# All weights have different quantization.
33+
a_node = build_node('a', layer_class=DummyLayer,
34+
final_weights={'aaweightaa': np.ones((3, 14)), 'foo': np.ones(15),
35+
1: np.ones(15), 2: np.ones((5, 9))},
36+
qcs=[build_nbits_qc(a_nbits=5,
37+
w_attr={'aaweightaa': (2, True), 'foo': (3, True)},
38+
pos_attr=(4, True, [1, 2]),
39+
convert_canonical_attr=False)])
40+
w_node = build_node('w', layer_class=DummyLayerWKernel,
41+
final_weights={'wwweightww': np.ones((2, 71)), 'bar': np.ones(8),
42+
1: np.ones(28), 3: np.ones(18)},
43+
qcs=[build_nbits_qc(a_nbits=6,
44+
w_attr={'wwweightww': (5, True), 'bar': (6, True)},
45+
pos_attr=(7, True, [1, 3]),
46+
convert_canonical_attr=False)])
47+
48+
fw_info_mock.get_kernel_op_attributes = lambda nt: ['weight'] if nt is DummyLayerWKernel else [None]
49+
50+
v_node = VirtualActivationWeightsNode(a_node, w_node, fw_info_mock)
51+
assert len(v_node.weights) == 8
52+
53+
assert len(w_node.weights) == len(a_node.weights) == 4
54+
# weights from weight node are unchanged
55+
for k, v in w_node.weights.items():
56+
assert np.array_equal(v_node.weights.pop(k), v)
57+
# unique weights from activation node are unchanged
58+
assert np.array_equal(v_node.weights.pop('foo'), a_node.weights['foo'])
59+
assert np.array_equal(v_node.weights.pop(2), a_node.weights[2])
60+
# duplicate positional weight
61+
assert np.array_equal(v_node.weights.pop(101), a_node.weights[1])
62+
# duplicate weight attribute
63+
[(new_attr, w)] = v_node.weights.items()
64+
assert 'weight' not in new_attr
65+
assert np.array_equal(w, a_node.weights['aaweightaa'])
66+
67+
assert len(v_node.candidates_quantization_cfg) == 1
68+
v_qc = v_node.candidates_quantization_cfg[0]
69+
v_attr_cfg = v_qc.weights_quantization_cfg.attributes_config_mapping
70+
v_pos_cfg = v_qc.weights_quantization_cfg.pos_attributes_config_mapping
71+
a_qc = a_node.candidates_quantization_cfg[0]
72+
w_qc = w_node.candidates_quantization_cfg[0]
73+
74+
assert v_attr_cfg == {
75+
'wwweightww': w_qc.weights_quantization_cfg.attributes_config_mapping['wwweightww'],
76+
'bar': w_qc.weights_quantization_cfg.attributes_config_mapping['bar'],
77+
'foo': a_qc.weights_quantization_cfg.attributes_config_mapping['foo'],
78+
new_attr: a_qc.weights_quantization_cfg.attributes_config_mapping['aaweightaa']
79+
}
80+
assert v_pos_cfg == {
81+
1: w_qc.weights_quantization_cfg.pos_attributes_config_mapping[1],
82+
101: a_qc.weights_quantization_cfg.pos_attributes_config_mapping[1],
83+
2: a_qc.weights_quantization_cfg.pos_attributes_config_mapping[2],
84+
3: w_qc.weights_quantization_cfg.pos_attributes_config_mapping[3]
85+
}
86+
87+
def test_invalid_configurable_w_node_weight(self, fw_info_mock):
88+
w_node = build_node('w', layer_class=DummyLayerWKernel,
89+
canonical_weights={'kernel': np.ones(3), 'foo': np.ones(14)},
90+
qcs=[
91+
build_nbits_qc(w_attr={'kernel': (8, True), 'foo': (8, True)}),
92+
build_nbits_qc(w_attr={'kernel': (8, True), 'foo': (4, True)})
93+
])
94+
a_node = build_node('a', qcs=[build_nbits_qc()])
95+
96+
fw_info_mock.get_kernel_op_attributes = lambda nt: ['kernel'] if nt is DummyLayerWKernel else [None]
97+
98+
with pytest.raises(NotImplementedError, match='Only kernel weight can be configurable. Got configurable .*foo'):
99+
VirtualActivationWeightsNode(a_node, w_node, fw_info_mock)
100+
101+
def test_invalid_a_node_configurable_weight(self, fw_info_mock):
102+
w_node = build_node('w', layer_class=DummyLayerWKernel,
103+
canonical_weights={'kernel': np.ones(3), 'foo': np.ones(14)},
104+
qcs=[
105+
build_nbits_qc(w_attr={'kernel': (8, True), 'foo': (8, True)}),
106+
build_nbits_qc(w_attr={'kernel': (4, True), 'foo': (8, True)})
107+
])
108+
a_node = build_node('aaa', canonical_weights={'bar': np.ones(3), 'baz': np.ones(14)},
109+
qcs=[
110+
build_nbits_qc(w_attr={'bar': (8, True), 'baz': (8, True)}),
111+
build_nbits_qc(w_attr={'bar': (8, True), 'baz': (4, True)})
112+
])
113+
fw_info_mock.get_kernel_op_attributes = lambda nt: ['kernel'] if nt is DummyLayerWKernel else [None]
114+
115+
with pytest.raises(NotImplementedError, match='Node .*aaa with a configurable weight cannot be used as '
116+
'activation for VirtualActivationWeightsNode'):
117+
VirtualActivationWeightsNode(a_node, w_node, fw_info_mock)
118+
119+
def test_invalid_a_node_kernel(self, fw_info_mock):
120+
w_node = build_node('w', layer_class=DummyLayerWKernel, canonical_weights={'weight': np.ones(3)},
121+
qcs=[build_nbits_qc(w_attr={'weight': (8, True)})])
122+
a_node = build_node('aaa', canonical_weights={'kernel': np.ones(3)},
123+
qcs=[build_nbits_qc(w_attr={'kernel': (8, True)})])
124+
fw_info_mock.get_kernel_op_attributes = lambda nt: ['weight'] if nt is DummyLayerWKernel else ['kernel']
125+
126+
with pytest.raises(NotImplementedError, match='Node .*aaa with kernel cannot be used as '
127+
'activation for VirtualActivationWeightsNode'):
128+
VirtualActivationWeightsNode(a_node, w_node, fw_info_mock)
129+

tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
RUTarget
3333
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
3434
Utilization, ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
35-
from tests_pytest._test_util.graph_builder_utils import build_node, build_qc, full_attr_name
35+
from tests_pytest._test_util.graph_builder_utils import build_node, full_attr_name, build_nbits_qc as build_qc
3636

3737
BM = BitwidthMode
3838
TIC = TargetInclusionCriterion
@@ -844,7 +844,7 @@ def test_compute_w_utilization_non_custom(self, prepare_compute_w_util):
844844

845845
def test_compute_w_utilization_no_targets(self, graph_mock, fw_impl_mock, fw_info_mock):
846846
graph_mock.nodes = [
847-
build_node('n1', qcs=build_qc()),
847+
build_node('n1', qcs=[build_qc()]),
848848
build_node('n2', canonical_weights={'foo': np.ones((5,))}, qcs=[build_qc(w_attr={'foo': (8, True)})])
849849
]
850850
ru_calc = ResourceUtilizationCalculator(graph_mock, fw_impl_mock, fw_info_mock)

tests_pytest/common_tests/unit_tests/test_model_collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from model_compression_toolkit.core.common.graph.edge import Edge
2424
from model_compression_toolkit.core.common.hessian import HessianInfoService
2525
from model_compression_toolkit.core.common.model_collector import create_stats_collector_for_node, create_tensor2node, ModelCollector
26-
from tests_pytest._test_util.graph_builder_utils import build_node, DummyLayer, build_qc
26+
from tests_pytest._test_util.graph_builder_utils import build_node, DummyLayer, build_nbits_qc as build_qc
2727

2828

2929
@pytest.fixture

0 commit comments

Comments
 (0)