Skip to content

Commit ad311a5

Browse files
authored
Refactor fw_info handling in pytest unittests (#1488)
* Switch pytest unittests to fw_info patch instead of changing the global _current_framework_info state.
* Remove activation_quantizer_mapping from DummyFrameworkInfo in tests.
1 parent 5ecc21f commit ad311a5

11 files changed

Lines changed: 96 additions & 134 deletions

tests_pytest/common_tests/unit_tests/core/graph/test_base_node.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,11 +15,9 @@
1515
import numpy as np
1616

1717
from tests_pytest._test_util.graph_builder_utils import build_node, build_nbits_qc
18-
from model_compression_toolkit.core.common.framework_info import set_fw_info
1918

2019

21-
def test_find_min_max_candidate_index(fw_info_mock):
22-
set_fw_info(fw_info_mock)
20+
def test_find_min_max_candidate_index(patch_fw_info):
2321
qcs = []
2422
for ab in [4, 8, 16, 2]:
2523
for fb in [2, 8, 4]:

tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py

Lines changed: 12 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121

2222
from model_compression_toolkit.quantization_preparation.load_fqc import set_quantization_configs_to_node
2323
from tests_pytest._test_util.graph_builder_utils import build_node, DummyLayer
24-
from model_compression_toolkit.core.common.framework_info import set_fw_info
2524
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
2625
OpQuantizationConfig, AttributeQuantizationConfig, Signedness
2726
from mct_quantizers import QuantizationMethod
@@ -36,10 +35,6 @@ class NoActivationQuantNode:
3635

3736

3837
class TestSetNodeQuantizationConfig:
39-
@pytest.fixture(autouse=True)
40-
def setup(self, fw_info_mock):
41-
set_fw_info(fw_info_mock)
42-
4338
@staticmethod
4439
def _get_op_config(activation_n_bits,
4540
supported_input_activation_n_bits,
@@ -55,10 +50,9 @@ def _get_op_config(activation_n_bits,
5550
quantization_preserving=quantization_preserving,
5651
signedness=Signedness.AUTO)
5752

58-
def test_activation_preserving_with_2_inputs(self, fw_info_mock):
53+
def test_activation_preserving_with_2_inputs(self, patch_fw_info):
5954
""" Tests that . """
60-
fw_info_mock.activation_quantizer_mapping = {QuantizationMethod.POWER_OF_TWO: lambda x: 0}
61-
fw_info_mock.get_kernel_op_attribute = lambda x: None
55+
patch_fw_info.get_kernel_op_attribute = lambda x: None
6256

6357
n1 = build_node('in1_node')
6458
n2 = build_node('in2_node')
@@ -71,9 +65,9 @@ def test_activation_preserving_with_2_inputs(self, fw_info_mock):
7165
Edge(n3, n4, 0, 0),
7266
Edge(n1, qp3, 0, 0), Edge(qp3, qp4, 0, 0)])
7367
q_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7,
74-
"enable_activation_quantization": True, "quantization_preserving": False}
68+
"enable_activation_quantization": True, "quantization_preserving": False}
7569
qp_op_config_kwargs = {"activation_n_bits": 7, "supported_input_activation_n_bits": 7,
76-
"enable_activation_quantization": False, "quantization_preserving": True}
70+
"enable_activation_quantization": False, "quantization_preserving": True}
7771
_filters = {DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**q_op_config_kwargs)]),
7872
PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**qp_op_config_kwargs)])}
7973
fqc = Mock(filterlayer2qco=_filters, layer2qco=_filters)
@@ -85,12 +79,11 @@ def test_activation_preserving_with_2_inputs(self, fw_info_mock):
8579
assert qp3.is_quantization_preserving()
8680
assert qp4.is_quantization_preserving()
8781

88-
def test_node_quantization_by_next_nodes(self, fw_info_mock):
82+
def test_node_quantization_by_next_nodes(self, patch_fw_info):
8983
"""
9084
Test that node quantization n_bits is unaffected by preserving next node and not-enabled quantization next node.
9185
"""
92-
fw_info_mock.activation_quantizer_mapping = {QuantizationMethod.POWER_OF_TWO: lambda x: 0}
93-
fw_info_mock.get_kernel_op_attribute = lambda x: None
86+
patch_fw_info.get_kernel_op_attribute = lambda x: None
9487

9588
first_node = build_node('first_node')
9689
preserving_node = build_node('preserving_node', layer_class=PreservingNode)
@@ -105,14 +98,14 @@ def test_node_quantization_by_next_nodes(self, fw_info_mock):
10598
"quantization_preserving": False}
10699

107100
preserving_node_config_kwargs = {"activation_n_bits": 8,
108-
"supported_input_activation_n_bits": [8, 16],
109-
"enable_activation_quantization": False,
110-
"quantization_preserving": True}
101+
"supported_input_activation_n_bits": [8, 16],
102+
"enable_activation_quantization": False,
103+
"quantization_preserving": True}
111104

112105
no_quant_node_config_kwargs = {"activation_n_bits": 8,
113-
"supported_input_activation_n_bits": [8],
114-
"enable_activation_quantization": False,
115-
"quantization_preserving": False}
106+
"supported_input_activation_n_bits": [8],
107+
"enable_activation_quantization": False,
108+
"quantization_preserving": False}
116109
_filters = {
117110
DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**first_node_config_kwargs)]),
118111
PreservingNode: QuantizationConfigOptions(quantization_configurations=[self._get_op_config(**preserving_node_config_kwargs)]),

tests_pytest/common_tests/unit_tests/core/graph/test_quantization_preserving_node.py

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -16,17 +16,12 @@
1616

1717
from model_compression_toolkit.core.common import Graph
1818
from model_compression_toolkit.core.common.graph.edge import Edge
19-
from model_compression_toolkit.core.common.framework_info import set_fw_info
2019

2120
from tests_pytest._test_util.graph_builder_utils import build_node, build_nbits_qc
2221

2322

2423
class TestQuantizationPreservingNode:
25-
@pytest.fixture(autouse=True)
26-
def setup(self, fw_info_mock):
27-
set_fw_info(fw_info_mock)
28-
29-
def test_activation_preserving_candidate(self):
24+
def test_activation_preserving_candidate(self, patch_fw_info):
3025
""" Tests that the correct activation quantization candidate is selected. """
3126
n1 = build_node('qact_node', qcs=[build_nbits_qc()])
3227
n2 = build_node('qp1a_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
@@ -42,7 +37,7 @@ def test_activation_preserving_candidate(self):
4237
assert graph.retrieve_preserved_quantization_node(n4) is n4
4338
assert graph.retrieve_preserved_quantization_node(n5) is n4
4439

45-
def test_activation_preserving_disable_for_multi_input_node(self):
40+
def test_activation_preserving_disable_for_multi_input_node(self, patch_fw_info):
4641
""" Tests that the retrieve_preserved_quantization_node raises an assertion error if node has more than 1 input. """
4742
n1 = build_node('qact_node', qcs=[build_nbits_qc()])
4843
n2 = build_node('qp1a_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])

tests_pytest/common_tests/unit_tests/core/graph/test_virtual_activation_weights_node.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
import numpy as np
1616
import pytest
1717

18-
from model_compression_toolkit.core.common.framework_info import set_fw_info
1918
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode
2019
from tests_pytest._test_util.graph_builder_utils import build_node, DummyLayer, build_nbits_qc
2120

@@ -25,17 +24,13 @@ class DummyLayerWKernel:
2524

2625

2726
class TestVirtualActivationWeightsNode:
28-
@pytest.fixture(autouse=True)
29-
def setup(self, fw_info_mock):
30-
set_fw_info(fw_info_mock)
31-
3227
# TODO tests only cover combining weights from activation and weight nodes and errors.
33-
def test_activation_with_weights(self, fw_info_mock):
28+
def test_activation_with_weights(self, patch_fw_info):
3429
""" Tests that weights from activation and weight node are combined correctly. """
3530
# Each node has a unique weight attr and a unique positional weights. In addition, both nodes have
3631
# an identical canonical attribute (but different full name), and an identical positional weight.
3732
# All weights have different quantization.
38-
fw_info_mock.get_kernel_op_attribute = lambda nt: 'weight' if nt is DummyLayerWKernel else None
33+
patch_fw_info.get_kernel_op_attribute = lambda nt: 'weight' if nt is DummyLayerWKernel else None
3934

4035
a_node = build_node('a', final_weights={'aaweightaa': np.ones((3, 14)), 'foo': np.ones(15),
4136
1: np.ones(15), 2: np.ones((5, 9))}, qcs=[build_nbits_qc(a_nbits=5,
@@ -101,8 +96,8 @@ def test_activation_with_weights(self, fw_info_mock):
10196
3: w_qc.weights_quantization_cfg.pos_attributes_config_mapping[3]
10297
}
10398

104-
def test_invalid_configurable_w_node_weight(self, fw_info_mock):
105-
fw_info_mock.get_kernel_op_attribute = lambda nt: 'kernel' if nt is DummyLayerWKernel else None
99+
def test_invalid_configurable_w_node_weight(self, patch_fw_info):
100+
patch_fw_info.get_kernel_op_attribute = lambda nt: 'kernel' if nt is DummyLayerWKernel else None
106101

107102
w_node = build_node('w', canonical_weights={'kernel': np.ones(3), 'foo': np.ones(14)}, qcs=[
108103
build_nbits_qc(w_attr={'kernel': (8, True), 'foo': (8, True)}),
@@ -113,8 +108,8 @@ def test_invalid_configurable_w_node_weight(self, fw_info_mock):
113108
with pytest.raises(NotImplementedError, match='Only kernel weight can be configurable. Got configurable .*foo'):
114109
VirtualActivationWeightsNode(a_node, w_node)
115110

116-
def test_invalid_a_node_configurable_weight(self, fw_info_mock):
117-
fw_info_mock.get_kernel_op_attribute = lambda nt: 'kernel' if nt is DummyLayerWKernel else None
111+
def test_invalid_a_node_configurable_weight(self, patch_fw_info):
112+
patch_fw_info.get_kernel_op_attribute = lambda nt: 'kernel' if nt is DummyLayerWKernel else None
118113

119114
w_node = build_node('w', canonical_weights={'kernel': np.ones(3), 'foo': np.ones(14)}, qcs=[
120115
build_nbits_qc(w_attr={'kernel': (8, True), 'foo': (8, True)}),
@@ -128,8 +123,8 @@ def test_invalid_a_node_configurable_weight(self, fw_info_mock):
128123
'activation for VirtualActivationWeightsNode'):
129124
VirtualActivationWeightsNode(a_node, w_node)
130125

131-
def test_invalid_a_node_kernel(self, fw_info_mock):
132-
fw_info_mock.get_kernel_op_attribute = lambda nt: 'weight' if nt is DummyLayerWKernel else 'kernel'
126+
def test_invalid_a_node_kernel(self, patch_fw_info):
127+
patch_fw_info.get_kernel_op_attribute = lambda nt: 'weight' if nt is DummyLayerWKernel else 'kernel'
133128
w_node = build_node('w', canonical_weights={'weight': np.ones(3)},
134129
qcs=[build_nbits_qc(w_attr={'weight': (8, True)})], layer_class=DummyLayerWKernel)
135130
a_node = build_node('aaa', canonical_weights={'kernel': np.ones(3)},

tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py

Lines changed: 19 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222

2323
from model_compression_toolkit.constants import FLOAT_BITWIDTH, FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
2424
from model_compression_toolkit.core import ResourceUtilization
25-
from model_compression_toolkit.core.common.framework_info import set_fw_info
2625
from model_compression_toolkit.core.common import Graph
2726
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
2827
from model_compression_toolkit.core.common.graph.edge import Edge
@@ -69,9 +68,8 @@ class TestComputeResourceUtilization:
6968
compute_resource_utilization on a virtual graph is tested in TestBOPSAndVirtualGraph
7069
"""
7170
@pytest.fixture(autouse=True)
72-
def setup(self, graph_mock, fw_impl_mock, fw_info_mock):
73-
fw_info_mock.get_kernel_op_attribute = Mock(return_value='foo') # for bops
74-
set_fw_info(fw_info_mock)
71+
def setup(self, graph_mock, fw_impl_mock, patch_fw_info):
72+
patch_fw_info.get_kernel_op_attribute = Mock(return_value='foo') # for bops
7573
fw_impl_mock.get_node_mac_operations = lambda n: 42 if n == n2 else 0 # for bops
7674
n1 = build_node('n1', qcs=[build_qc()], output_shape=(None, 5, 10))
7775
n2 = build_node('n2', canonical_weights={'foo': np.zeros((3, 14))}, qcs=[build_qc(w_attr={'foo': (4, True)})],
@@ -226,8 +224,8 @@ def _validate(self, ret, detailed, exp_ru: ResourceUtilization):
226224

227225
class TestActivationUtilizationMethods:
228226
@pytest.fixture(autouse=True)
229-
def setup(self, fw_info_mock):
230-
set_fw_info(fw_info_mock)
227+
def setup(self, patch_fw_info):
228+
pass
231229

232230
""" Tests for non-public activation utilization api. """
233231
def test_get_a_nbits_configurable(self, graph_mock, fw_impl_mock):
@@ -337,8 +335,8 @@ def test_get_target_activation_nodes(self, graph_mock, fw_impl_mock):
337335

338336
class TestComputeActivationTensorsUtilization:
339337
@pytest.fixture(autouse=True)
340-
def setup(self, fw_info_mock):
341-
set_fw_info(fw_info_mock)
338+
def setup(self, patch_fw_info):
339+
pass
342340

343341
""" Tests for activation tensors utilization public apis. """
344342
def test_compute_node_activation_tensor_utilization(self, graph_mock, fw_impl_mock):
@@ -436,8 +434,8 @@ def test_compute_act_tensors_util_invalid_custom_qcs(self, graph_mock, fw_impl_m
436434

437435
class TestActivationMaxCutUtilization:
438436
@pytest.fixture(autouse=True)
439-
def setup(self, fw_info_mock):
440-
set_fw_info(fw_info_mock)
437+
def setup(self, patch_fw_info):
438+
pass
441439

442440
""" Tests for activation max cut utilization. """
443441
def test_compute_cuts_integration(self, graph_mock, fw_impl_mock, mocker):
@@ -737,8 +735,8 @@ def test_compute_act_utilization_by_cut_invalid_custom_qcs(self, graph_mock, fw_
737735
class TestWeightUtilizationMethods:
738736
""" Tests for weights utilization non-public api. """
739737
@pytest.fixture(autouse=True)
740-
def setup(self, fw_info_mock):
741-
set_fw_info(fw_info_mock)
738+
def setup(self, patch_fw_info):
739+
pass
742740

743741
def test_get_w_nbits(self, graph_mock, fw_impl_mock):
744742
ru_calc = ResourceUtilizationCalculator(graph_mock, fw_impl_mock)
@@ -850,8 +848,8 @@ def test_collect_target_nodes_w_attrs(self, graph_mock, fw_impl_mock):
850848
class TestComputeNodeWeightsUtilization:
851849
""" Tests for compute_node_weight_utilization public method. """
852850
@pytest.fixture(autouse=True)
853-
def setup(self, fw_info_mock):
854-
set_fw_info(fw_info_mock)
851+
def setup(self, patch_fw_info):
852+
pass
855853

856854
@pytest.fixture
857855
def setup_node_w_test(self, graph_mock, fw_impl_mock):
@@ -940,8 +938,8 @@ def test_compute_node_w_utilization_errors(self, graph_mock, fw_impl_mock, setup
940938
class TestComputeWeightUtilization:
941939
""" Tests for compute_weight_utilization public method. """
942940
@pytest.fixture(autouse=True)
943-
def setup(self, fw_info_mock):
944-
set_fw_info(fw_info_mock)
941+
def setup(self, patch_fw_info):
942+
pass
945943

946944
@pytest.fixture
947945
def prepare_compute_w_util(self, fw_impl_mock):
@@ -1050,8 +1048,8 @@ def test_compute_w_utilization_invalid_custom_qcs(self, graph_mock, fw_impl_mock
10501048
class TestCalculatorMisc:
10511049
""" Calculator tests that don't belong to other test classes """
10521050
@pytest.fixture(autouse=True)
1053-
def setup(self, fw_info_mock):
1054-
set_fw_info(fw_info_mock)
1051+
def setup(self, patch_fw_info):
1052+
pass
10551053

10561054
def test_calculator_init(self, fw_impl_mock):
10571055
n1 = build_node('n1', qcs=[build_qc(a_enable=False)], output_shape=(None, 5, 10))
@@ -1089,8 +1087,8 @@ class BOPNode:
10891087

10901088
class TestBOPSAndVirtualGraph:
10911089
@pytest.fixture(autouse=True)
1092-
def setup(self, fw_info_mock):
1093-
set_fw_info(fw_info_mock)
1090+
def setup(self, patch_fw_info):
1091+
pass
10941092

10951093
def test_compute_regular_node_bops(self, fw_impl_mock, fw_info_mock):
10961094
fw_info_mock.get_kernel_op_attribute = lambda node_type: 'foo' if node_type == BOPNode else None
@@ -1177,7 +1175,7 @@ def test_node_bops_invalid_criterion(self, graph_mock, fw_impl_mock, target_crit
11771175
with pytest.raises(ValueError, match='BOPS computation is supported only for Any, AnyQuantized and AnyQuantizedNonFused targets.'):
11781176
ru_calc.compute_node_bops(Mock(), target_criterion, BM.Float)
11791177

1180-
def test_compute_bops(self, fw_impl_mock, fw_info_mock,):
1178+
def test_compute_bops(self, fw_impl_mock, fw_info_mock):
11811179
class BOPNode2:
11821180
pass
11831181

tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_data.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626

2727
class TestResourceUtilizationData:
2828
@pytest.mark.parametrize('error_method', [QuantizationErrorMethod.MSE, QuantizationErrorMethod.HMSE])
29-
def test_resource_utilization_data(self, fw_info_mock, fw_impl_mock, error_method, mocker):
29+
def test_resource_utilization_data(self, fw_impl_mock, error_method, mocker):
3030
core_cfg = CoreConfig()
3131
core_cfg.quantization_config.weights_error_method = error_method
3232
core_cfg.bit_width_config = BitWidthConfig([1, 2])
@@ -42,7 +42,6 @@ def test_resource_utilization_data(self, fw_info_mock, fw_impl_mock, error_metho
4242
prep_runner = mocker.patch('model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.'
4343
'resource_utilization_data.graph_preparation_runner')
4444

45-
_current_framework_info = fw_info_mock
4645
compute_resource_utilization_data(model_mock,
4746
data_gen_mock,
4847
core_cfg,

tests_pytest/common_tests/unit_tests/core/mixed_precision/sensitivity_eval/test_distance_calculator.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
import numpy as np
1818
import pytest
1919

20-
from model_compression_toolkit.core.common.framework_info import set_fw_info
2120
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, MpDistanceWeighting
2221
from model_compression_toolkit.core.common.hessian import HessianInfoService
2322
from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.metric_calculators import \
@@ -34,8 +33,7 @@ class TestDistanceWeighting:
3433
out_pts = np.array([[1, 2], [3, 4], [5, 6]])
3534

3635
@pytest.fixture
37-
def setup(self, mocker, graph_mock, fw_info_mock, fw_impl_mock):
38-
set_fw_info(fw_info_mock)
36+
def setup(self, mocker, graph_mock, patch_fw_info, fw_impl_mock):
3937
mocker.patch.object(DistanceMetricCalculator, 'get_mp_interest_points', return_value=[None, None])
4038
mocker.patch.object(DistanceMetricCalculator, 'get_output_nodes_for_metric', return_value=[None])
4139
mocker.patch.object(DistanceMetricCalculator, '_init_metric_points_lists', return_value=(None, None))

0 commit comments

Comments (0)