 from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
 from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
-from model_compression_toolkit.core.pytorch.utils import to_torch_tensor


 def custom_float_metric(model_mp) -> float:
@@ -39,20 +38,7 @@ def custom_none_metric(model_mp):
     return None


-@pytest.fixture
-def mock_model_configuration():
-    return Mock()
-
-
-@pytest.fixture
-def sensitivity_evaluator_factory():
-    def _create_sensitivity_evaluator(custom_metric_fn):
-        mp_cfg = MixedPrecisionQuantizationConfig(custom_metric_fn=custom_metric_fn)
-        return get_sensitivity_evaluator(mp_cfg=mp_cfg)
-    return _create_sensitivity_evaluator
-
-
-def get_sensitivity_evaluator(mp_cfg):
+def get_sensitivity_evaluator(custom_metric_fn):
     mock_graph = Mock()
     mock_graph.get_topo_sorted_nodes.return_value = ['test', 'this', 'is', 'reset']
     mock_graph.get_outputs.return_value = []
@@ -71,14 +57,16 @@ def custom_model_builder_return_value(*args, **kwargs):
         return (None, None, None)

     def custom_to_tensor(img):
-        return to_torch_tensor(img)
+        return img

     mock_fw_impl = Mock()
     mock_fw_impl.model_builder.side_effect = custom_model_builder_return_value
     mock_fw_impl.to_tensor.side_effect = custom_to_tensor

     mock_set_layer_to_bitwidth = Mock()

+    mp_cfg = MixedPrecisionQuantizationConfig(custom_metric_fn=custom_metric_fn)
+
     sensitivity_eval = SensitivityEvaluation(graph=mock_graph,
                                              quant_config=mp_cfg,
                                              representative_data_gen=representative_data_gen,
@@ -97,17 +85,17 @@ class TestMPCustomMetricFunction:
         (custom_float_metric, 100.0),
         (custom_np_float_metric, np.float64(100.0)),
     ])
-    def test_valid_metric_function(self, sensitivity_evaluator_factory, mock_model_configuration, metric_fn, expected):
-        sensitivity_eval = sensitivity_evaluator_factory(metric_fn)
+    def test_valid_metric_function(self, metric_fn, expected):
+        sensitivity_eval = get_sensitivity_evaluator(metric_fn)
         assert len(sensitivity_eval.interest_points) == 0
-        assert sensitivity_eval.compute_metric(mock_model_configuration) == expected
+        assert sensitivity_eval.compute_metric(Mock()) == expected

     @pytest.mark.parametrize("metric_fn, expected", [
         (custom_str_metric, str.__name__),
         (custom_none_metric, type(None).__name__),
     ])
-    def test_type_invalid_metric_function(self, sensitivity_evaluator_factory, mock_model_configuration, metric_fn, expected):
-        sensitivity_eval = sensitivity_evaluator_factory(metric_fn)
+    def test_type_invalid_metric_function(self, metric_fn, expected):
+        sensitivity_eval = get_sensitivity_evaluator(metric_fn)
         assert len(sensitivity_eval.interest_points) == 0
         with pytest.raises(TypeError, match=f'The custom_metric_fn is expected to return float or numpy float, got {expected}'):
-            sensitivity_metric = sensitivity_eval.compute_metric(mock_model_configuration)
+            sensitivity_metric = sensitivity_eval.compute_metric(Mock())