Skip to content

Commit 24f5dca

Browse files
feat(metrics): enhance DinoScore model validation and error handling
- Added a method to return all valid model keys for better validation. - Improved error handling to raise a clear ValueError for unrecognized model keys. - Updated model parameter documentation for clarity and completeness. - Refactored model loading logic to streamline the process for both TIMM and Hugging Face models. - Added a new test to ensure invalid model keys are properly handled.
1 parent ab1f65a commit 24f5dca

File tree

2 files changed

+30
-23
lines changed

2 files changed

+30
-23
lines changed

src/pruna/evaluation/metrics/metric_dino_score.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,18 @@ class DinoScore(StatefulMetric):
5050
reference images.
5151
5252
DINO v1 and DINOv2 load via timm. DINOv3 loads via Hugging Face Transformers (>=4.56.0).
53-
See https://github.com/facebookresearch/dinov3 and
54-
https://huggingface.co/collections/facebook/dinov3 for available models.
5553
5654
Parameters
5755
----------
5856
device : str | torch.device | None
5957
The device to use for the metric.
6058
model : str
61-
Backbone name. "dino" (default), "dinov2_vits14", "dinov2_vitb14",
62-
"dinov2_vitl14", "dinov3_vits16", "dinov3_vits16plus", "dinov3_vitb16",
63-
"dinov3_vitl16", "dinov3_vith16plus", "dinov3_vit7b16",
64-
"dinov3_convnext_tiny/small/base/large", "dinov3_vitl16_sat493m",
65-
"dinov3_vit7b16_sat493m", etc. DINOv3 uses HF Transformers; DINO v1/v2
66-
use timm. Any timm or HF model ID also accepted.
59+
One of the registered model keys. TIMM_MODELS keys: "dino", "dinov2_vits14",
60+
"dinov2_vitb14", "dinov2_vitl14". HF_DINOV3_MODELS keys: "dinov3_vits16",
61+
"dinov3_vits16plus", "dinov3_vitb16", "dinov3_vitl16", "dinov3_vith16plus",
62+
"dinov3_vit7b16", "dinov3_convnext_tiny", "dinov3_convnext_small",
63+
"dinov3_convnext_base", "dinov3_convnext_large", "dinov3_vitl16_sat493m",
64+
"dinov3_vit7b16_sat493m".
6765
call_type : str
6866
The call type to use for the metric.
6967
@@ -96,6 +94,11 @@ class DinoScore(StatefulMetric):
9694
"dinov3_vit7b16_sat493m": "facebook/dinov3-vit7b16-pretrain-sat493m",
9795
}
9896

97+
@classmethod
98+
def valid_models(cls) -> list[str]:
99+
"""Return all valid model keys."""
100+
return list(cls.TIMM_MODELS) + list(cls.HF_DINOV3_MODELS)
101+
99102
similarities: List[Tensor]
100103
metric_name: str = DINO_SCORE
101104
higher_is_better: bool = True
@@ -114,20 +117,19 @@ def __init__(
114117
pruna_logger.error(f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}")
115118
raise
116119
self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
117-
self.model_name = model
120+
valid = self.valid_models()
121+
if model not in valid:
122+
raise ValueError(f"Unknown DinoScore model '{model}'. Valid keys: {valid}")
118123

119-
hf_name = self.HF_DINOV3_MODELS.get(model)
120-
if hf_name is not None or (model.startswith("facebook/") and "dinov3" in model):
124+
if model in self.HF_DINOV3_MODELS:
121125
from transformers import AutoModel
122126

123-
hf_name = hf_name or model
124-
self.model = AutoModel.from_pretrained(hf_name)
127+
self.model = AutoModel.from_pretrained(self.HF_DINOV3_MODELS[model])
125128
self.model.eval().to(self.device)
126129
self._use_transformers = True
127130
h = 224
128131
else:
129-
timm_name = self.TIMM_MODELS.get(model, model)
130-
self.model = timm.create_model(timm_name, pretrained=True)
132+
self.model = timm.create_model(self.TIMM_MODELS[model], pretrained=True)
131133
self.model.eval().to(self.device)
132134
self._use_transformers = False
133135
h = self.model.default_cfg.get("input_size", (3, 224, 224))[1]
@@ -145,8 +147,9 @@ def _get_embeddings(self, x: Tensor) -> Tensor:
145147
if self._use_transformers:
146148
out = self.model(pixel_values=x)
147149
return out.pooler_output
148-
features = self.model.forward_features(x)
149-
return features["x_norm_clstoken"] if isinstance(features, dict) else features[:, 0]
150+
else:
151+
features = self.model.forward_features(x)
152+
return features["x_norm_clstoken"] if isinstance(features, dict) else features[:, 0]
150153

151154
@torch.no_grad()
152155
def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> None:

tests/evaluation/test_dino_score.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,19 @@
33
from pruna.evaluation.metrics.metric_dino_score import DinoScore
44

55
DINO_MODELS = [
6-
pytest.param("dino", id="dino_v1"),
7-
pytest.param("dinov2_vits14", id="dinov2_vits14", marks=pytest.mark.slow),
8-
pytest.param("dinov2_vitb14", id="dinov2_vitb14", marks=pytest.mark.slow),
9-
pytest.param("dinov2_vitl14", id="dinov2_vitl14", marks=pytest.mark.slow),
6+
"dino",
7+
pytest.param("dinov2_vits14", marks=pytest.mark.slow),
8+
pytest.param("dinov2_vitb14", marks=pytest.mark.slow),
9+
pytest.param("dinov2_vitl14", marks=pytest.mark.slow),
1010
pytest.param(
1111
"dinov3_vits16",
12-
id="dinov3_vits16",
1312
marks=[
1413
pytest.mark.slow,
1514
pytest.mark.skip(reason="DINOv3 HF models are gated; requires access approval"),
1615
],
1716
),
1817
pytest.param(
1918
"dinov3_convnext_tiny",
20-
id="dinov3_convnext_tiny",
2119
marks=[
2220
pytest.mark.slow,
2321
pytest.mark.skip(reason="DINOv3 HF models are gated; requires access approval"),
@@ -40,6 +38,12 @@ def test_dino_score_models(model: str):
4038
assert -1.0 - 1e-5 <= result.result <= 1.0 + 1e-5
4139

4240

41+
def test_dino_score_invalid_model():
42+
"""Test that an unrecognised model key raises a clear ValueError."""
43+
with pytest.raises(ValueError, match="Unknown DinoScore model"):
44+
DinoScore(device="cpu", model="facebook/dinov3-irrelevant-wrong-model")
45+
46+
4347
def test_dino_score():
4448
"""Test the DinoScore metric with default model (backward compatibility)."""
4549
metric = DinoScore(device="cpu")

0 commit comments

Comments
 (0)