Skip to content

Commit 4a1f731

Browse files
irenabirenab
authored andcommitted
organize tests
1 parent 8bb1c3d commit 4a1f731

File tree

9 files changed

+35
-102
lines changed

9 files changed

+35
-102
lines changed

tests_pytest/common/core/common/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py

Lines changed: 3 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,14 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
from types import MethodType
16-
from typing import Iterable, Union
1716
from unittest.mock import Mock
1817

1918
import numpy as np
2019
import pytest
21-
from mct_quantizers import QuantizationMethod
2220

2321
from model_compression_toolkit.constants import FLOAT_BITWIDTH
24-
from model_compression_toolkit.core import QuantizationConfig, ResourceUtilization
25-
from model_compression_toolkit.core.common import Graph, BaseNode
22+
from model_compression_toolkit.core import ResourceUtilization
23+
from model_compression_toolkit.core.common import Graph
2624
from model_compression_toolkit.core.common.graph.edge import Edge
2725
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut
2826
from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
@@ -34,100 +32,7 @@
3432
RUTarget
3533
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
3634
Utilization, ResourceUtilizationCalculator, TargetInclusionCriterion, BitwidthMode
37-
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
38-
CandidateNodeQuantizationConfig
39-
from model_compression_toolkit.core.common.quantization.node_quantization_config import \
40-
NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig
41-
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
42-
AttributeQuantizationConfig, Signedness
43-
44-
45-
def full_attr_name(canonical_name: Union[str, dict, Iterable]):
46-
""" Convert canonical attr (such as 'kernel') into a full name originated from the layer (e.g. 'conv2d_1/kernel:0')
47-
We just need the names to differ from canonical to make sure we call the correct apis. We use the same
48-
template for simplicity, so we don't have to explicitly synchronize names between node and weight configs."""
49-
convert = lambda name: f'{name[0]}/{name}/{name[-1]}' if isinstance(name, str) else name
50-
if isinstance(canonical_name, str):
51-
return convert(canonical_name)
52-
assert isinstance(canonical_name, (list, tuple, set))
53-
return canonical_name.__class__([convert(name) for name in canonical_name])
54-
55-
56-
def build_qc(a_nbits=8, a_enable=True, w_attr=None, pos_attr=(32, False, ())):
57-
""" Build quantization config for tests.
58-
w_attr contains {canonical name: (nbits, q_enabled)}
59-
pos_attr: (nbits, q enabled, indices) """
60-
w_attr = w_attr or {}
61-
attr_weights_configs_mapping = {
62-
k: AttributeQuantizationConfig(weights_n_bits=v[0], enable_weights_quantization=v[1])
63-
for k, v in w_attr.items()
64-
}
65-
qc = QuantizationConfig()
66-
# positional attrs are set via default weight config (so all pos attrs have the same q config)
67-
op_cfg = OpQuantizationConfig(
68-
# canonical names (as 'kernel')
69-
attr_weights_configs_mapping=attr_weights_configs_mapping,
70-
activation_n_bits=a_nbits,
71-
enable_activation_quantization=a_enable,
72-
default_weight_attr_config=AttributeQuantizationConfig(weights_n_bits=pos_attr[0],
73-
enable_weights_quantization=pos_attr[1]),
74-
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
75-
quantization_preserving=False,
76-
supported_input_activation_n_bits=[2, 4, 8],
77-
fixed_scale=None,
78-
fixed_zero_point=None,
79-
simd_size=None,
80-
signedness=Signedness.AUTO
81-
)
82-
a_qcfg = NodeActivationQuantizationConfig(qc=qc, op_cfg=op_cfg,
83-
activation_quantization_fn=None,
84-
activation_quantization_params_fn=None)
85-
# full names from the layers
86-
attr_names = [full_attr_name(k) for k in w_attr.keys()]
87-
w_qcfg = NodeWeightsQuantizationConfig(qc=qc, op_cfg=op_cfg,
88-
weights_channels_axis=None,
89-
node_attrs_list=attr_names + list(pos_attr[2]))
90-
qc = CandidateNodeQuantizationConfig(activation_quantization_cfg=a_qcfg,
91-
weights_quantization_cfg=w_qcfg)
92-
93-
# we generate q configs via constructors to follow the real code as closely as reasonably possible.
94-
# verify that we actually got the configurations we want
95-
assert qc.activation_quantization_cfg.activation_n_bits == a_nbits
96-
assert qc.activation_quantization_cfg.enable_activation_quantization is a_enable
97-
for k, v in w_attr.items():
98-
# get_attr_config accepts canonical attr names
99-
assert qc.weights_quantization_cfg.get_attr_config(k).weights_n_bits == v[0]
100-
assert qc.weights_quantization_cfg.get_attr_config(k).enable_weights_quantization == v[1]
101-
for pos in pos_attr[2]:
102-
assert qc.weights_quantization_cfg.get_attr_config(pos).weights_n_bits == pos_attr[0]
103-
assert qc.weights_quantization_cfg.get_attr_config(pos).enable_weights_quantization == pos_attr[1]
104-
105-
return qc
106-
107-
108-
class DummyLayer:
109-
""" Only needed for repr(node) to work. """
110-
pass
111-
112-
113-
def build_node(name='node', canonical_weights: dict=None, qcs=None, input_shape=(4, 5, 6), output_shape=(4, 5, 6),
114-
layer_class=DummyLayer, reuse=False):
115-
""" Build a node for tests.
116-
Canonical weights are converted into full unique names.
117-
candidate_quantization_cfg is set is qcs is passed."""
118-
weights = canonical_weights or {}
119-
weights = {k if isinstance(k, int) else full_attr_name(k): w for k, w in weights.items()}
120-
node = BaseNode(name=name,
121-
framework_attr={},
122-
input_shape=input_shape,
123-
output_shape=output_shape,
124-
weights=weights,
125-
layer_class=layer_class,
126-
reuse=reuse)
127-
if qcs:
128-
node.candidates_quantization_cfg = qcs
129-
return node
130-
35+
from tests_pytest.test_util.graph_builder_utils import build_node, build_qc, full_attr_name
13136

13237
BM = BitwidthMode
13338
TIC = TargetInclusionCriterion

tests_pytest/common/test_model_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
import pytest
16-
from unittest.mock import Mock, call
16+
from unittest.mock import Mock
1717
import numpy as np
1818
from numpy.testing import assert_array_equal
1919

@@ -23,7 +23,7 @@
2323
from model_compression_toolkit.core.common.graph.edge import Edge
2424
from model_compression_toolkit.core.common.hessian import HessianInfoService
2525
from model_compression_toolkit.core.common.model_collector import create_stats_collector_for_node, create_tensor2node, ModelCollector
26-
from tests_pytest.common.graph_builder_utils import build_node, DummyLayer, build_qc
26+
from tests_pytest.test_util.graph_builder_utils import build_node, DummyLayer, build_qc
2727

2828

2929
@pytest.fixture

tests_pytest/keras/core/mixed_precision/test_resource_utilization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from keras.layers import Conv2D, Conv2DTranspose, DepthwiseConv2D, Dense, Input, Subtract, Flatten
2020

2121
from tests_pytest.base_test_classes.base_test_ru_integration import BaseRUIntegrationTester
22-
from tests_pytest.keras.keras_test_mixin import KerasFwMixin
22+
from tests_pytest.keras.keras_test_util.keras_test_mixin import KerasFwMixin
2323

2424

2525
class TestRUIntegrationKeras(BaseRUIntegrationTester, KerasFwMixin):
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
File renamed without changes.

tests_pytest/pytorch/core/mixed_precision/test_resource_utilization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
1919
from tests_pytest.base_test_classes.base_test_ru_integration import BaseRUIntegrationTester
20-
from tests_pytest.pytorch.torch_test_mixin import TorchFwMixin
20+
from tests_pytest.pytorch.torch_test_util.torch_test_mixin import TorchFwMixin
2121

2222

2323
class TestRUIntegrationTorch(BaseRUIntegrationTester, TorchFwMixin):
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)