Skip to content

Commit 6b8a9f5

Browse files
committed
fix test to really be common
1 parent d5bf801 commit 6b8a9f5

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,12 @@
1717

1818
from unittest.mock import Mock
1919
from tests_pytest._test_util.graph_builder_utils import build_node, build_nbits_qc, DummyLayer
20+
from model_compression_toolkit.core import FrameworkInfo
2021
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configs_to_node
2122
from model_compression_toolkit.core import QuantizationConfig
2223
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
2324
OpQuantizationConfig, AttributeQuantizationConfig, Signedness
2425
from mct_quantizers import QuantizationMethod
25-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
26-
27-
28-
29-
from model_compression_toolkit.target_platform_capabilities import FrameworkQuantizationCapabilities
3026

3127

3228
class TestSetNodeQuantizationConfig:
@@ -43,7 +39,7 @@ def _get_op_config():
4339
quantization_preserving=True,
4440
signedness=Signedness.AUTO)
4541

46-
def test_activation_preserving_with_2_inputs(self, fw_impl_mock):
42+
def test_activation_preserving_with_2_inputs(self, fw_info_mock):
4743
""" Tests that a node whose inputs feed quantization-preserving ops ends up neither quantization-preserving nor activation-quantized. """
4844
n1 = build_node('in1_node')
4945
n2 = build_node('in2_node')
@@ -55,7 +51,11 @@ def test_activation_preserving_with_2_inputs(self, fw_impl_mock):
5551

5652
fqc = Mock(filterlayer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])},
5753
layer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])})
58-
set_quantization_configs_to_node(n3, graph, QuantizationConfig(), DEFAULT_PYTORCH_INFO, fqc)
59-
set_quantization_configs_to_node(n4, graph, QuantizationConfig(), DEFAULT_PYTORCH_INFO, fqc)
54+
fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
55+
activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
56+
get_kernel_op_attributes=lambda x: [None])
57+
set_quantization_configs_to_node(n3, graph, QuantizationConfig(), fw_info_mock, fqc)
58+
set_quantization_configs_to_node(n4, graph, QuantizationConfig(), fw_info_mock, fqc)
6059
assert not n3.is_quantization_preserving() and not n3.is_activation_quantization_enabled()
6160
assert not n4.is_quantization_preserving() and not n4.is_activation_quantization_enabled()
61+

0 commit comments

Comments
 (0)