Skip to content

Commit 4b9949a

Browse files
feat(metrics): DINO Score CLS fix, HF models, multi-model support
- Fix CLS token extraction ([:,0] for v1/v3, x_norm_clstoken for v2)
- Add DINOv3 via Hugging Face (facebook/dinov3-*)
- Add DINOv2 (torch.hub), DINO v1 (timm)
- Parametrized tests for each model

Made-with: Cursor
1 parent 209e12e commit 4b9949a

File tree

2 files changed

+143
-37
lines changed

2 files changed

+143
-37
lines changed

src/pruna/evaluation/metrics/metric_dino_score.py

Lines changed: 102 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any, List
17+
from typing import Any, List, Literal
1818

19-
import timm
2019
import torch
2120

2221
# Ruff complains when we don't import functional as f, but common practice is to import it as F
2322
import torch.nn.functional as F # noqa: N812
2423
from torch import Tensor
2524
from torchvision import transforms
25+
from torchvision.transforms.functional import to_pil_image
2626

2727
from pruna.engine.utils import set_to_best_available_device
2828
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
@@ -33,6 +33,14 @@
3333

3434
DINO_SCORE = "dino_score"
3535

36+
DINO_PREPROCESS = transforms.Compose(
37+
[
38+
transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
39+
transforms.CenterCrop(224),
40+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
41+
]
42+
)
43+
3644

3745
@MetricRegistry.register(DINO_SCORE)
3846
class DinoScore(StatefulMetric):
@@ -41,49 +49,117 @@ class DinoScore(StatefulMetric):
4149
4250
A similarity metric based on DINO (self-distillation with no labels),
4351
a self-supervised vision transformer trained to learn high-level image representations without annotations.
44-
DinoScore compares the embeddings of generated and reference images in this representation space,
52+
DinoScore compares the [CLS] token embeddings of generated and reference images in this representation space,
4553
producing a value where higher scores indicate that the generated images preserve more of the semantic content of the
4654
reference images.
4755
48-
Reference
56+
Supports DINO (v1), DINOv2, and DINOv3 backbones. DINOv3 uses Hugging Face Transformers
57+
(facebook/dinov3-*) with weights on Hugging Face Hub. Requires transformers>=4.56.0.
58+
DINOv3 models are gated; accept the model at huggingface.co before first use.
59+
60+
References
4961
----------
50-
https://github.com/facebookresearch/dino
51-
https://arxiv.org/abs/2104.14294
62+
DINO: https://github.com/facebookresearch/dino, https://arxiv.org/abs/2104.14294
63+
DINOv2: https://github.com/facebookresearch/dinov2
64+
DINOv3: https://github.com/facebookresearch/dinov3
5265
5366
Parameters
5467
----------
5568
device : str | torch.device | None
5669
The device to use for the metric.
70+
model : str
71+
Backbone variant. "dino" uses timm vit_small_patch16_224.dino (DINO v1).
72+
"dinov2_*" uses torch.hub facebookresearch/dinov2. "dinov3_*" uses
73+
Hugging Face facebook/dinov3-* (ViT and ConvNeXt).
5774
call_type : str
5875
The call type to use for the metric.
5976
"""
6077

78+
DINOV3_HF_MODELS: dict[str, str] = {
79+
"dinov3_vits16": "facebook/dinov3-vits16-pretrain-lvd1689m",
80+
"dinov3_vits16plus": "facebook/dinov3-vits16plus-pretrain-lvd1689m",
81+
"dinov3_vitb16": "facebook/dinov3-vitb16-pretrain-lvd1689m",
82+
"dinov3_vitl16": "facebook/dinov3-vitl16-pretrain-lvd1689m",
83+
"dinov3_vith16plus": "facebook/dinov3-vith16plus-pretrain-lvd1689m",
84+
"dinov3_vit7b16": "facebook/dinov3-vit7b16-pretrain-lvd1689m",
85+
"dinov3_convnext_tiny": "facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
86+
"dinov3_convnext_small": "facebook/dinov3-convnext-small-pretrain-lvd1689m",
87+
"dinov3_convnext_base": "facebook/dinov3-convnext-base-pretrain-lvd1689m",
88+
"dinov3_convnext_large": "facebook/dinov3-convnext-large-pretrain-lvd1689m",
89+
"dinov3_vitl16_sat": "facebook/dinov3-vitl16-pretrain-sat493m",
90+
"dinov3_vit7b16_sat": "facebook/dinov3-vit7b16-pretrain-sat493m",
91+
}
92+
6193
similarities: List[Tensor]
6294
metric_name: str = DINO_SCORE
6395
higher_is_better: bool = True
64-
runs_on: List[str] = ["cuda", "cpu"]
96+
runs_on: List[str] = ["cuda", "cpu", "mps"]
6597
default_call_type: str = "gt_y"
6698

67-
def __init__(self, device: str | torch.device | None = None, call_type: str = SINGLE):
68-
super().__init__()
99+
def __init__(
100+
self,
101+
device: str | torch.device | None = None,
102+
model: str = "dino",
103+
call_type: str = SINGLE,
104+
):
105+
super().__init__(device=device)
69106
self.device = set_to_best_available_device(device)
70107
if device is not None and not any(self.device.startswith(prefix) for prefix in self.runs_on):
71108
pruna_logger.error(f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}")
72109
raise
73110
self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
74-
# Load the DINO ViT-S/16 model once
75-
self.model = timm.create_model("vit_small_patch16_224.dino", pretrained=True)
111+
self.model_name = model
112+
loaded = self._load_model(model)
113+
if isinstance(loaded, tuple):
114+
self.model, self._hf_processor = loaded
115+
self.processor = None
116+
else:
117+
self.model = loaded
118+
self._hf_processor = None
119+
self.processor = DINO_PREPROCESS
76120
self.model.eval().to(self.device)
77-
# Add internal state to accumulate similarities
78121
self.add_state("similarities", default=[])
79-
self.processor = transforms.Compose(
80-
[
81-
transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
82-
transforms.CenterCrop(224),
83-
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
84-
]
122+
123+
def _load_model(
124+
self,
125+
model: str,
126+
) -> torch.nn.Module | tuple[torch.nn.Module, object]:
127+
if model == "dino":
128+
import timm
129+
return timm.create_model("vit_small_patch16_224.dino", pretrained=True)
130+
if model.startswith("dinov2_"):
131+
return torch.hub.load("facebookresearch/dinov2", model)
132+
if model in self.DINOV3_HF_MODELS:
133+
from transformers import AutoImageProcessor, AutoModel
134+
hf_id = self.DINOV3_HF_MODELS[model]
135+
processor = AutoImageProcessor.from_pretrained(hf_id)
136+
backbone = AutoModel.from_pretrained(hf_id)
137+
return backbone, processor
138+
raise ValueError(
139+
f"Unsupported model: {model}. "
140+
f"DINOv3 options: {list(self.DINOV3_HF_MODELS.keys())}"
85141
)
86142

143+
def _get_embeddings(self, x: Tensor) -> Tensor:
144+
if self.model_name == "dino":
145+
features = self.model.forward_features(x)
146+
return features[:, 0]
147+
if self.model_name.startswith("dinov2_"):
148+
out = self.model.forward_features(x)
149+
return out["x_norm_clstoken"]
150+
features = self.model.forward_features(x)
151+
if isinstance(features, dict):
152+
return features["x_norm_clstoken"]
153+
return features[:, 0]
154+
155+
def _get_embeddings_hf(self, x: Tensor) -> Tensor:
156+
images = [to_pil_image(x[i]) for i in range(x.shape[0])]
157+
inputs = self._hf_processor(images=images, return_tensors="pt")
158+
pixel_values = inputs["pixel_values"].to(self.device)
159+
with torch.no_grad():
160+
outputs = self.model(pixel_values)
161+
return outputs.pooler_output
162+
87163
@torch.no_grad()
88164
def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> None:
89165
"""
@@ -100,13 +176,14 @@ def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> No
100176
"""
101177
metric_inputs = metric_data_processor(x, gt, outputs, self.call_type)
102178
inputs, preds = metric_inputs
103-
inputs = self.processor(inputs)
104-
preds = self.processor(preds)
105-
# Extract embeddings ([CLS] token)
106-
emb_x = self.model.forward_features(inputs)
107-
emb_y = self.model.forward_features(preds)
108-
109-
# Normalize embeddings
179+
if self._hf_processor is not None:
180+
emb_x = self._get_embeddings_hf(inputs)
181+
emb_y = self._get_embeddings_hf(preds)
182+
else:
183+
inputs = self.processor(inputs)
184+
preds = self.processor(preds)
185+
emb_x = self._get_embeddings(inputs)
186+
emb_y = self._get_embeddings(preds)
110187
emb_x = F.normalize(emb_x, dim=1)
111188
emb_y = F.normalize(emb_y, dim=1)
112189

tests/evaluation/test_dino_score.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,51 @@
22
import pytest
33
from pruna.evaluation.metrics.metric_dino_score import DinoScore
44

5-
def test_dino_score():
6-
"""Test the DinoScore metric."""
7-
# Use CPU for testing
8-
metric = DinoScore(device="cpu")
5+
DINO_MODELS = [
6+
pytest.param("dino", id="dino_v1"),
7+
pytest.param("dinov2_vits14", id="dinov2_vits14", marks=pytest.mark.slow),
8+
pytest.param("dinov2_vitb14", id="dinov2_vitb14", marks=pytest.mark.slow),
9+
pytest.param("dinov2_vitl14", id="dinov2_vitl14", marks=pytest.mark.slow),
10+
pytest.param(
11+
"dinov3_vits16",
12+
id="dinov3_vits16",
13+
marks=[
14+
pytest.mark.slow,
15+
pytest.mark.skip(reason="facebook/dinov3-* are gated; accept at huggingface.co first"),
16+
],
17+
),
18+
pytest.param(
19+
"dinov3_convnext_tiny",
20+
id="dinov3_convnext_tiny",
21+
marks=[
22+
pytest.mark.slow,
23+
pytest.mark.skip(reason="facebook/dinov3-* are gated; accept at huggingface.co first"),
24+
],
25+
),
26+
]
927

10-
# Create dummy images (batch of 2 images, 3x224x224)
11-
x = torch.rand(2, 3, 224, 224)
12-
y = torch.rand(2,3, 224, 224)
1328

14-
# Update metric
29+
@pytest.mark.cpu
30+
@pytest.mark.parametrize("model", DINO_MODELS)
31+
def test_dino_score_models(model: str):
32+
"""Test DinoScore with each supported backbone (dino, dinov2, dinov3)."""
33+
metric = DinoScore(device="cpu", model=model)
34+
x = torch.rand(2, 3, 224, 224)
35+
y = torch.rand(2, 3, 224, 224)
1536
metric.update(x, y, y)
16-
17-
# Compute result
1837
result = metric.compute()
38+
assert result.name == "dino_score"
39+
assert isinstance(result.result, float)
40+
assert -1.0 - 1e-5 <= result.result <= 1.0 + 1e-5
1941

42+
43+
def test_dino_score():
44+
"""Test the DinoScore metric with default model (backward compatibility)."""
45+
metric = DinoScore(device="cpu")
46+
x = torch.rand(2, 3, 224, 224)
47+
y = torch.rand(2, 3, 224, 224)
48+
metric.update(x, y, y)
49+
result = metric.compute()
2050
assert result.name == "dino_score"
2151
assert isinstance(result.result, float)
22-
# Cosine similarity should be between -1 and 1
23-
assert -1.0 <= result.result <= 1.0
52+
assert -1.0 - 1e-5 <= result.result <= 1.0 + 1e-5

0 commit comments

Comments (0)