add support for custom metric function for mixed precision #1420

Idan-BenAmi merged 6 commits into SonySemiconductorSolutions:main from
Conversation
```python
compute_distance_fn: Optional[Callable] = None
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
custom_metric_fn: Optional[Callable] = None
```
Please add a description to the docstring, including the expected API of the function (the args it accepts and what it should return).
```python
return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
                                                 node_idx,
                                                 topo_cfg(baseline_cfg) if baseline_cfg else None)
```

```python
if self.sensitivity_evaluator.quant_config.custom_metric_fn is None:
```
I think it's better to keep a single entry point to the sensitivity evaluator and keep the MP manager agnostic to this, i.e. call self.sensitivity_evaluator.compute_metric and let it decide what to do. There's no reason to spread the logic between two places.
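A minimal sketch of the reviewer's suggestion (class and method names here are illustrative stand-ins, not the real MCT implementation): the dispatch between the default and custom metric lives inside compute_metric, so callers never branch on custom_metric_fn themselves.

```python
from typing import Callable, Optional


class SensitivityEvaluator:
    """Toy stand-in for MCT's sensitivity evaluator, showing a single
    compute_metric entry point that dispatches internally."""

    def __init__(self, custom_metric_fn: Optional[Callable] = None):
        self.custom_metric_fn = custom_metric_fn

    def compute_metric(self, mp_model) -> float:
        # Dispatch here, so the MP manager stays agnostic to which metric is used.
        if self.custom_metric_fn is not None:
            return self.custom_metric_fn(mp_model)
        return self._default_metric(mp_model)

    def _default_metric(self, mp_model) -> float:
        # Placeholder for the default distance-based sensitivity metric.
        return 0.0
```

With this shape, the MP manager's call site stays a single unconditional `compute_metric(...)` call regardless of configuration.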
```python
return self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
                                         self.quant_config.distance_weighting_method)
```

```python
def compute_custom_metric(self,
```
Consider uniting the two methods: configure -> compute default or custom metric -> configure back, instead of replicating the configuration in two places. You can move line 193 before line 189, so there shouldn't be a problem.
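The "configure -> compute -> configure back" pattern the reviewer describes can be sketched as follows (hypothetical names; the real evaluator's configuration state is more involved than a single attribute):

```python
from contextlib import contextmanager


class Evaluator:
    """Toy sketch: configuration is applied once around the metric
    computation and always restored, rather than being duplicated in a
    default-metric method and a custom-metric method."""

    def __init__(self):
        self.cfg = "baseline"

    @contextmanager
    def _configured(self, cfg):
        prev = self.cfg
        self.cfg = cfg  # configure
        try:
            yield
        finally:
            self.cfg = prev  # configure back

    def compute(self, cfg, metric_fn):
        # One place applies/restores the configuration; metric_fn can be
        # either the default or a user-supplied custom metric.
        with self._configured(cfg):
            return metric_fn(self.cfg)
```

Centralizing the configure/restore step means a custom metric cannot accidentally leave the evaluator in a mixed-precision configuration after an exception.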
```python
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
```
Why do you need the dependency on torch? I would expect this test to be under common, without any framework dependencies
```python
@pytest.fixture
def sensitivity_evaluator_factory():
    def _create_sensitivity_evaluator(custom_metric_fn):
```
What's the benefit of having this as a fixture? It complicates the definition: you need to pass it to the test as a fixture, and then you call it anyway inside the test. You could just define it directly (or combine it with get_sensitivity_evaluator) and call it from the test as a regular function/method.
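The reviewer's point can be sketched like this (all names below are illustrative, not the actual test code): when the test calls the factory itself anyway, a plain helper function removes the fixture indirection with no loss of functionality.

```python
from typing import Callable, Optional


class FakeSensitivityEvaluator:
    """Stand-in for the real SensitivityEvaluation; only stores the fn."""

    def __init__(self, custom_metric_fn: Optional[Callable]):
        self.custom_metric_fn = custom_metric_fn


def create_sensitivity_evaluator(custom_metric_fn: Optional[Callable] = None):
    # Plain helper: tests call it directly, no pytest fixture needed.
    return FakeSensitivityEvaluator(custom_metric_fn)


def test_custom_metric_is_stored():
    fn = lambda model_mp: 0.0
    ev = create_sensitivity_evaluator(fn)
    assert ev.custom_metric_fn is fn
```

pytest still collects and runs `test_custom_metric_is_stored` as usual; the only change is that the factory is an ordinary function rather than a fixture parameter.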
Merged commit 357b42b into SonySemiconductorSolutions:main
Pull Request Description:
Add support for a custom metric function for mixed precision.
The function will receive the mixed-precision model (model_mp) as input and return a metric score.
model_mp will return all of the model's outputs, not just the interest points.
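To illustrate the contract described above, here is a hedged sketch of what such a custom metric could look like. The function name, sample data, and scoring logic are all assumptions for illustration; the only parts taken from the PR description are the signature (model_mp in, float score out) and the fact that model_mp returns all model outputs.

```python
import math


def custom_accuracy_metric(model_mp) -> float:
    """Illustrative custom metric (not MCT code): run the mixed-precision
    model on a few samples and reduce all of its outputs to one score.
    Per the PR description, model_mp returns all the model's outputs,
    not only the interest points."""
    samples = [[0.0, 1.0], [2.0, 3.0]]  # stand-in calibration data
    outputs = [model_mp(s) for s in samples]
    # Toy score: mean L2 norm over each sample's full output vector.
    norms = [math.sqrt(sum(v * v for v in out)) for out in outputs]
    return sum(norms) / len(norms)
```

Based on the diff above, such a function would presumably be passed via `MixedPrecisionQuantizationConfig(custom_metric_fn=custom_accuracy_metric)`.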
Checklist before requesting a review: