Skip to content

Commit d4f2e8b

Browse files
feat(metrics): DINO Score v3 (#568)
* feat(metrics): DINO Score v3 — [CLS] fix, HF models, multi-model support
  - Use [CLS] token embeddings for DINO v1/v2 (timm)
  - Add DINOv3 support via Hugging Face Transformers
  - Dynamic preprocessing based on model input size
  - Support dinov2_vits14/vitb14/vitl14, dinov3 variants, convnext
  - Made-with: Cursor

* refactor(metrics): reorganize references in DinoScore class
  - Moved the references section within the DinoScore class docstring for better visibility.
  - Ensured all relevant DINO model links are included in the updated references section.

* refactor(metrics): improve import organization in metric_dino_score.py
  - Reformatted import statements for better readability and consistency.
  - Grouped related imports together to enhance code structure.

* feat(metrics): enhance DinoScore model validation and error handling
  - Added a method to return all valid model keys for better validation.
  - Improved error handling to raise a clear ValueError for unrecognized model keys.
  - Updated model parameter documentation for clarity and completeness.
  - Refactored model loading logic to streamline the process for both TIMM and Hugging Face models.
  - Added a new test to ensure invalid model keys are properly handled.
1 parent 637eaff commit d4f2e8b

File tree

2 files changed

+128
-33
lines changed

2 files changed

+128
-33
lines changed

src/pruna/evaluation/metrics/metric_dino_score.py

Lines changed: 83 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
2929
from pruna.evaluation.metrics.registry import MetricRegistry
3030
from pruna.evaluation.metrics.result import MetricResult
31-
from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor
31+
from pruna.evaluation.metrics.utils import (
32+
SINGLE,
33+
get_call_type_for_single_metric,
34+
metric_data_processor,
35+
)
3236
from pruna.logging.logger import pruna_logger
3337

3438
DINO_SCORE = "dino_score"
@@ -41,50 +45,113 @@ class DinoScore(StatefulMetric):
4145
4246
A similarity metric based on DINO (self-distillation with no labels),
4347
a self-supervised vision transformer trained to learn high-level image representations without annotations.
44-
DinoScore compares the embeddings of generated and reference images in this representation space,
48+
DinoScore compares the [CLS] token embeddings of generated and reference images in this representation space,
4549
producing a value where higher scores indicate that the generated images preserve more of the semantic content of the
4650
reference images.
4751
48-
Reference
49-
----------
50-
https://github.com/facebookresearch/dino
51-
https://arxiv.org/abs/2104.14294
52+
DINO v1 and DINOv2 load via timm. DINOv3 loads via Hugging Face Transformers (>=4.56.0).
5253
5354
Parameters
5455
----------
5556
device : str | torch.device | None
5657
The device to use for the metric.
58+
model : str
59+
One of the registered model keys. TIMM_MODELS keys: "dino", "dinov2_vits14",
60+
"dinov2_vitb14", "dinov2_vitl14". HF_DINOV3_MODELS keys: "dinov3_vits16",
61+
"dinov3_vits16plus", "dinov3_vitb16", "dinov3_vitl16", "dinov3_vith16plus",
62+
"dinov3_vit7b16", "dinov3_convnext_tiny", "dinov3_convnext_small",
63+
"dinov3_convnext_base", "dinov3_convnext_large", "dinov3_vitl16_sat493m",
64+
"dinov3_vit7b16_sat493m".
5765
call_type : str
5866
The call type to use for the metric.
67+
68+
References
69+
----------
70+
DINO: https://github.com/facebookresearch/dino, https://arxiv.org/abs/2104.14294
71+
DINOv2: https://github.com/facebookresearch/dinov2
72+
DINOv3: https://github.com/facebookresearch/dinov3
5973
"""
6074

75+
# Registry of DINO v1 / DINOv2 backbones, loaded through timm.
# Key: the user-facing model key accepted by __init__; value: the timm model id.
TIMM_MODELS: dict[str, str] = {
    "dino": "vit_small_patch16_224.dino",
    "dinov2_vits14": "vit_small_patch14_dinov2.lvd142m",
    "dinov2_vitb14": "vit_base_patch14_dinov2.lvd142m",
    "dinov2_vitl14": "vit_large_patch14_dinov2.lvd142m",
}

# Registry of DINOv3 backbones, loaded through Hugging Face Transformers.
# Key: the user-facing model key; value: the Hugging Face Hub repository id.
# NOTE(review): these repos are gated on the Hub — access approval is required
# before from_pretrained succeeds (see skipped tests in the test suite).
HF_DINOV3_MODELS: dict[str, str] = {
    "dinov3_vits16": "facebook/dinov3-vits16-pretrain-lvd1689m",
    "dinov3_vits16plus": "facebook/dinov3-vits16plus-pretrain-lvd1689m",
    "dinov3_vitb16": "facebook/dinov3-vitb16-pretrain-lvd1689m",
    "dinov3_vitl16": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "dinov3_vith16plus": "facebook/dinov3-vith16plus-pretrain-lvd1689m",
    "dinov3_vit7b16": "facebook/dinov3-vit7b16-pretrain-lvd1689m",
    "dinov3_convnext_tiny": "facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
    "dinov3_convnext_small": "facebook/dinov3-convnext-small-pretrain-lvd1689m",
    "dinov3_convnext_base": "facebook/dinov3-convnext-base-pretrain-lvd1689m",
    "dinov3_convnext_large": "facebook/dinov3-convnext-large-pretrain-lvd1689m",
    "dinov3_vitl16_sat493m": "facebook/dinov3-vitl16-pretrain-sat493m",
    "dinov3_vit7b16_sat493m": "facebook/dinov3-vit7b16-pretrain-sat493m",
}
96+
97+
@classmethod
def valid_models(cls) -> list[str]:
    """Return every recognised model key (timm-backed keys first, then HF DINOv3 keys)."""
    return [*cls.TIMM_MODELS, *cls.HF_DINOV3_MODELS]
101+
61102
similarities: List[Tensor]
62103
metric_name: str = DINO_SCORE
63104
higher_is_better: bool = True
64105
runs_on: List[str] = ["cuda", "cpu"]
65106
default_call_type: str = "gt_y"
66107

67-
def __init__(self, device: str | torch.device | None = None, call_type: str = SINGLE):
68-
super().__init__()
108+
def __init__(
    self,
    device: str | torch.device | None = None,
    model: str = "dino",
    call_type: str = SINGLE,
):
    """Build the metric: validate device and model key, load the backbone, set up preprocessing.

    Parameters
    ----------
    device : str | torch.device | None
        The device to use for the metric.
    model : str
        A key from TIMM_MODELS (DINO v1 / DINOv2 via timm) or HF_DINOV3_MODELS
        (DINOv3 via Hugging Face Transformers).
    call_type : str
        The call type to use for the metric.
    """
    super().__init__(device=device)
    self.device = set_to_best_available_device(device)

    # Reject explicitly requested devices that are outside the supported set.
    device_supported = device is None or any(self.device.startswith(prefix) for prefix in self.runs_on)
    if not device_supported:
        msg = f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}"
        pruna_logger.error(msg)
        raise ValueError(msg)

    self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)

    valid = self.valid_models()
    if model not in valid:
        raise ValueError(f"Unknown DinoScore model '{model}'. Valid keys: {valid}")

    # Load the requested backbone; DINOv3 comes from Hugging Face, the rest from timm.
    self._use_transformers = model in self.HF_DINOV3_MODELS
    if self._use_transformers:
        from transformers import AutoModel

        self.model = AutoModel.from_pretrained(self.HF_DINOV3_MODELS[model])
        crop_size = 224
    else:
        self.model = timm.create_model(self.TIMM_MODELS[model], pretrained=True)
        # timm exposes the expected input size as (C, H, W); fall back to 224.
        crop_size = self.model.default_cfg.get("input_size", (3, 224, 224))[1]
    self.model.eval().to(self.device)

    # Internal state accumulating per-batch cosine similarities.
    self.add_state("similarities", default=[])

    # Classic resize-then-center-crop pipeline, scaled to the model's input size
    # (256/224 preserves the conventional resize/crop ratio), then ImageNet normalization.
    resize_edge = int(crop_size * 256 / 224)
    self.processor = transforms.Compose(
        [
            transforms.Resize(resize_edge, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(crop_size),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
87146

147+
def _get_embeddings(self, x: Tensor) -> Tensor:
148+
if self._use_transformers:
149+
out = self.model(pixel_values=x)
150+
return out.pooler_output
151+
else:
152+
features = self.model.forward_features(x)
153+
return features["x_norm_clstoken"] if isinstance(features, dict) else features[:, 0]
154+
88155
@torch.no_grad()
89156
def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> None:
90157
"""
@@ -103,15 +170,10 @@ def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> No
103170
inputs, preds = metric_inputs
104171
inputs = self.processor(inputs)
105172
preds = self.processor(preds)
106-
# Extract embeddings ([CLS] token)
107-
emb_x = self.model.forward_features(inputs)
108-
emb_y = self.model.forward_features(preds)
109-
110-
# Normalize embeddings
173+
emb_x = self._get_embeddings(inputs)
174+
emb_y = self._get_embeddings(preds)
111175
emb_x = F.normalize(emb_x, dim=1)
112176
emb_y = F.normalize(emb_y, dim=1)
113-
114-
# Compute cosine similarity
115177
sim = (emb_x * emb_y).sum(dim=1)
116178
self.similarities.append(sim)
117179

tests/evaluation/test_dino_score.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,55 @@
22
import pytest
33
from pruna.evaluation.metrics.metric_dino_score import DinoScore
44

5-
def test_dino_score():
6-
"""Test the DinoScore metric."""
7-
# Use CPU for testing
8-
metric = DinoScore(device="cpu")
5+
# Backbones exercised by test_dino_score_models. Only "dino" runs in the fast
# suite; larger DINOv2 variants are marked slow, and DINOv3 variants are also
# skipped because the Hugging Face repos are gated.
DINO_MODELS = [
    "dino",
    pytest.param("dinov2_vits14", marks=pytest.mark.slow),
    pytest.param("dinov2_vitb14", marks=pytest.mark.slow),
    pytest.param("dinov2_vitl14", marks=pytest.mark.slow),
    pytest.param(
        "dinov3_vits16",
        marks=[
            pytest.mark.slow,
            pytest.mark.skip(reason="DINOv3 HF models are gated; requires access approval"),
        ],
    ),
    pytest.param(
        "dinov3_convnext_tiny",
        marks=[
            pytest.mark.slow,
            pytest.mark.skip(reason="DINOv3 HF models are gated; requires access approval"),
        ],
    ),
]
925

10-
# Create dummy images (batch of 2 images, 3x224x224)
11-
x = torch.rand(2, 3, 224, 224)
12-
y = torch.rand(2,3, 224, 224)
1326

14-
# Update metric
27+
@pytest.mark.cpu
@pytest.mark.parametrize("model", DINO_MODELS)
def test_dino_score_models(model: str):
    """Run DinoScore end-to-end on each supported backbone (dino, dinov2, dinov3)."""
    metric = DinoScore(device="cpu", model=model)
    generated = torch.rand(2, 3, 224, 224)
    reference = torch.rand(2, 3, 224, 224)
    metric.update(generated, reference, reference)
    result = metric.compute()
    assert result.name == "dino_score"
    assert isinstance(result.result, float)
    # Cosine similarity, up to float tolerance: |score| <= 1.
    assert abs(result.result) <= 1.0 + 1e-5
39+
1940

41+
def test_dino_score_invalid_model():
    """An unrecognised model key must raise a descriptive ValueError."""
    bad_key = "facebook/dinov3-irrelevant-wrong-model"
    with pytest.raises(ValueError, match="Unknown DinoScore model"):
        DinoScore(device="cpu", model=bad_key)
45+
46+
47+
def test_dino_score():
    """DinoScore with its default backbone still works (backward compatibility)."""
    metric = DinoScore(device="cpu")
    generated, reference = (torch.rand(2, 3, 224, 224) for _ in range(2))
    metric.update(generated, reference, reference)
    result = metric.compute()
    assert result.name == "dino_score"
    assert isinstance(result.result, float)
    # Cosine similarity, up to float tolerance: |score| <= 1.
    assert abs(result.result) <= 1.0 + 1e-5

0 commit comments

Comments
 (0)