Skip to content

Commit 0f9017c

Browse files
committed
move test to common and remove fixture
1 parent b70bb65 commit 0f9017c

File tree

2 files changed

+10
-36
lines changed

2 files changed

+10
-36
lines changed

tests_pytest/pytorch_tests/unit_tests/core/mixed_precision/test_custom_metric_function.py renamed to tests_pytest/common_tests/unit_tests/core/mixed_precision/test_custom_metric_function.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
2121
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
2222
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
23-
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
2423

2524

2625
def custom_float_metric(model_mp) -> float:
@@ -39,20 +38,7 @@ def custom_none_metric(model_mp):
3938
return None
4039

4140

42-
@pytest.fixture
43-
def mock_model_configuration():
44-
return Mock()
45-
46-
47-
@pytest.fixture
48-
def sensitivity_evaluator_factory():
49-
def _create_sensitivity_evaluator(custom_metric_fn):
50-
mp_cfg = MixedPrecisionQuantizationConfig(custom_metric_fn=custom_metric_fn)
51-
return get_sensitivity_evaluator(mp_cfg=mp_cfg)
52-
return _create_sensitivity_evaluator
53-
54-
55-
def get_sensitivity_evaluator(mp_cfg):
41+
def get_sensitivity_evaluator(custom_metric_fn):
5642
mock_graph = Mock()
5743
mock_graph.get_topo_sorted_nodes.return_value = ['test', 'this', 'is', 'reset']
5844
mock_graph.get_outputs.return_value = []
@@ -71,14 +57,16 @@ def custom_model_builder_return_value(*args, **kwargs):
7157
return (None, None, None)
7258

7359
def custom_to_tensor(img):
74-
return to_torch_tensor(img)
60+
return img
7561

7662
mock_fw_impl = Mock()
7763
mock_fw_impl.model_builder.side_effect = custom_model_builder_return_value
7864
mock_fw_impl.to_tensor.side_effect = custom_to_tensor
7965

8066
mock_set_layer_to_bitwidth = Mock()
8167

68+
mp_cfg = MixedPrecisionQuantizationConfig(custom_metric_fn=custom_metric_fn)
69+
8270
sensitivity_eval = SensitivityEvaluation(graph=mock_graph,
8371
quant_config=mp_cfg,
8472
representative_data_gen=representative_data_gen,
@@ -97,17 +85,17 @@ class TestMPCustomMetricFunction:
9785
(custom_float_metric, 100.0),
9886
(custom_np_float_metric, np.float64(100.0)),
9987
])
100-
def test_valid_metric_function(self, sensitivity_evaluator_factory, mock_model_configuration, metric_fn, expected):
101-
sensitivity_eval = sensitivity_evaluator_factory(metric_fn)
88+
def test_valid_metric_function(self, metric_fn, expected):
89+
sensitivity_eval = get_sensitivity_evaluator(metric_fn)
10290
assert len(sensitivity_eval.interest_points) == 0
103-
assert sensitivity_eval.compute_metric(mock_model_configuration) == expected
91+
assert sensitivity_eval.compute_metric(Mock()) == expected
10492

10593
@pytest.mark.parametrize("metric_fn, expected", [
10694
(custom_str_metric, str.__name__),
10795
(custom_none_metric, type(None).__name__),
10896
])
109-
def test_type_invalid_metric_function(self, sensitivity_evaluator_factory, mock_model_configuration, metric_fn, expected):
110-
sensitivity_eval = sensitivity_evaluator_factory(metric_fn)
97+
def test_type_invalid_metric_function(self, metric_fn, expected):
98+
sensitivity_eval = get_sensitivity_evaluator(metric_fn)
11199
assert len(sensitivity_eval.interest_points) == 0
112100
with pytest.raises(TypeError, match=f'The custom_metric_fn is expected to return float or numpy float, got {expected}'):
113-
sensitivity_metric = sensitivity_eval.compute_metric(mock_model_configuration)
101+
sensitivity_metric = sensitivity_eval.compute_metric(Mock())

tests_pytest/pytorch_tests/unit_tests/core/mixed_precision/__init__.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments
 (0)