Skip to content

add support for custom metric function for mixed precision#1420

Merged
Idan-BenAmi merged 6 commits intoSonySemiconductorSolutions:mainfrom
itai-berman:task_loss_api
Apr 28, 2025
Merged

add support for custom metric function for mixed precision#1420
Idan-BenAmi merged 6 commits intoSonySemiconductorSolutions:mainfrom
itai-berman:task_loss_api

Conversation

@itai-berman
Copy link
Copy Markdown
Contributor

@itai-berman itai-berman commented Apr 22, 2025

Pull Request Description:

Add support for custom metric function for mixed precision.
The function will get the model_mp as input and return a metric score.
model_mp will return all the model's outputs and not the interest points.

Checklist before requesting a review:

  • I set the appropriate labels on the pull request.
  • I have added/updated the release note draft (if necessary).
  • I have updated the documentation to reflect my changes (if necessary).
  • All function and files are well documented.
  • All function and classes have type hints.
  • There is a licenses in all file.
  • The function and variable names are informative.
  • I have checked for code duplications.
  • I have added new unittest (if necessary).


compute_distance_fn: Optional[Callable] = None
distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
custom_metric_fn: Optional[Callable] = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add description to docstring, including the expected api of the function (args it accepts and what it should return).

return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
node_idx,
topo_cfg(baseline_cfg) if baseline_cfg else None)
if self.sensitivity_evaluator.quant_config.custom_metric_fn is None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to keep a single entry point to sensitivity evaluator and keep mp manager agnostic to this, i.e. call self.sensitivity_evaluator.compute_metric and let it decide what to do. No reason to spread the logic between two places.

return self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
self.quant_config.distance_weighting_method)

def compute_custom_metric(self,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider uniting the two methods. configure -> compute default or custom metric -> configure back, instead of replicating configuration in two places. You can move line 193 before 189, so there shouldn't be a problem.

from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need the dependency on torch? I would expect this test to be under common, without any framework dependencies


@pytest.fixture
def sensitivity_evaluator_factory():
def _create_sensitivity_evaluator(custom_metric_fn):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the benefit of having this as a fixture? It complicates the definition, you need to pass it to test as a fixture, and then you call it anyway inside the test. You could just define it directly (or combine with get_sensitivity_evaluator) and call from test as a regular function/method.

@Idan-BenAmi Idan-BenAmi merged commit 357b42b into SonySemiconductorSolutions:main Apr 28, 2025
28 checks passed
@itai-berman itai-berman deleted the task_loss_api branch April 28, 2025 11:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants