Skip to content

Commit 9b49913

Browse files
feat(metrics): DINO Score CLS fix, multi-model support, paper docstring fixes
- DINO Score: fix CLS token extraction ([:, 0] for v1/v3, x_norm_clstoken for v2)
- DINO Score: add model options (dino, dinov2_vits14, dinov2_vitb14, dinov3_*)
- DINO Score: add MPS support
- VieScore, Image Edit Score, VQA: update docstrings per paper refs
- VQA: add use_probability for P(Yes) via logprobs (litellm)
- Add tests for each DINO model (parametrized, slow mark for dinov2)

Made-with: Cursor
1 parent deab4b5 commit 9b49913

File tree

6 files changed

+205
-77
lines changed

6 files changed

+205
-77
lines changed

src/pruna/evaluation/metrics/metric_dino_score.py

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any, List
17+
from typing import Any, List, Literal
1818

19-
import timm
2019
import torch
2120

2221
# Ruff complains when we don't import functional as f, but common practice is to import it as F
@@ -33,6 +32,14 @@
3332

3433
DINO_SCORE = "dino_score"
3534

35+
DINO_PREPROCESS = transforms.Compose(
36+
[
37+
transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
38+
transforms.CenterCrop(224),
39+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
40+
]
41+
)
42+
3643

3744
@MetricRegistry.register(DINO_SCORE)
3845
class DinoScore(StatefulMetric):
@@ -41,48 +48,97 @@ class DinoScore(StatefulMetric):
4148
4249
A similarity metric based on DINO (self-distillation with no labels),
4350
a self-supervised vision transformer trained to learn high-level image representations without annotations.
44-
DinoScore compares the embeddings of generated and reference images in this representation space,
51+
DinoScore compares the [CLS] token embeddings of generated and reference images in this representation space,
4552
producing a value where higher scores indicate that the generated images preserve more of the semantic content of the
4653
reference images.
4754
48-
Reference
55+
Supports DINO (v1), DINOv2, and DINOv3 backbones. DINOv3 models may require weights from Meta's download form.
56+
57+
References
4958
----------
50-
https://github.com/facebookresearch/dino
51-
https://arxiv.org/abs/2104.14294
59+
DINO: https://github.com/facebookresearch/dino, https://arxiv.org/abs/2104.14294
60+
DINOv2: https://github.com/facebookresearch/dinov2
61+
DINOv3: https://github.com/facebookresearch/dinov3
5262
5363
Parameters
5464
----------
5565
device : str | torch.device | None
5666
The device to use for the metric.
67+
model : {"dino", "dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov3_vits16", "dinov3_vitb16", "dinov3_vitl16"}
68+
Backbone variant. "dino" uses timm vit_small_patch16_224.dino (DINO v1).
69+
"dinov2_*" uses torch.hub facebookresearch/dinov2. "dinov3_*" uses timm (requires timm>=1.0.20).
5770
call_type : str
5871
The call type to use for the metric.
5972
"""
6073

6174
similarities: List[Tensor]
6275
metric_name: str = DINO_SCORE
6376
higher_is_better: bool = True
64-
runs_on: List[str] = ["cuda", "cpu"]
77+
runs_on: List[str] = ["cuda", "cpu", "mps"]
6578
default_call_type: str = "gt_y"
6679

67-
def __init__(
    self,
    device: str | torch.device | None = None,
    model: Literal[
        "dino", "dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov3_vits16", "dinov3_vitb16", "dinov3_vitl16"
    ] = "dino",
    call_type: str = SINGLE,
):
    """
    Load the requested DINO backbone and prepare the metric state.

    Parameters
    ----------
    device : str | torch.device | None
        Requested device. The effective device is whatever
        ``set_to_best_available_device`` resolves it to.
    model : str
        Backbone identifier forwarded to ``_load_model``.
    call_type : str
        Call type, resolved against ``default_call_type``.

    Raises
    ------
    ValueError
        If a device was requested explicitly and the resolved device does not
        start with one of the prefixes in ``runs_on``, or if ``_load_model``
        rejects ``model``.
    """
    super().__init__(device=device)
    self.device = set_to_best_available_device(device)
    # str() keeps the prefix check valid whether the helper returns a string or
    # a torch.device (its return type is not visible from this block).
    if device is not None and not str(self.device).startswith(tuple(self.runs_on)):
        message = f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}"
        pruna_logger.error(message)
        # A bare `raise` has no active exception here and would surface as an
        # unrelated "No active exception to re-raise" RuntimeError.
        raise ValueError(message)
    self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
    self.model_name = model
    self.model = self._load_model(model)
    self.model.eval().to(self.device)
    # Register the `similarities` state with an empty-list default.
    self.add_state("similarities", default=[])
    self.processor = DINO_PREPROCESS
99+
100+
def _load_model(
101+
self,
102+
model: str,
103+
) -> torch.nn.Module:
104+
if model == "dino":
105+
import timm
106+
return timm.create_model("vit_small_patch16_224.dino", pretrained=True)
107+
if model.startswith("dinov2_"):
108+
return torch.hub.load("facebookresearch/dinov2", model)
109+
if model.startswith("dinov3_"):
110+
import timm
111+
timm_map = {
112+
"dinov3_vits16": "vit_small_patch16_dinov3.lvd1689m",
113+
"dinov3_vitb16": "vit_base_patch16_dinov3.lvd1689m",
114+
"dinov3_vitl16": "vit_large_patch16_dinov3.lvd1689m",
115+
}
116+
timm_name = timm_map.get(model)
117+
if timm_name is None:
118+
raise ValueError(f"Unsupported DINOv3 model: {model}. Choose from {list(timm_map.keys())}")
119+
try:
120+
return timm.create_model(timm_name, pretrained=True)
121+
except Exception as e:
122+
raise ValueError(
123+
f"DINOv3 requires timm>=1.0.20 and model weights from Meta. "
124+
f"See https://github.com/facebookresearch/dinov3. Error: {e}"
125+
) from e
126+
raise ValueError(f"Unsupported model: {model}")
127+
128+
def _get_embeddings(self, x: Tensor) -> Tensor:
129+
if self.model_name == "dino":
130+
features = self.model.forward_features(x)
131+
return features[:, 0]
132+
if self.model_name.startswith("dinov2_"):
133+
out = self.model.forward_features(x)
134+
return out["x_norm_clstoken"]
135+
if self.model_name.startswith("dinov3_"):
136+
features = self.model.forward_features(x)
137+
return features[:, 0]
138+
features = self.model.forward_features(x)
139+
if isinstance(features, dict):
140+
return features["x_norm_clstoken"]
141+
return features[:, 0]
86142

87143
@torch.no_grad()
88144
def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> None:
@@ -102,11 +158,8 @@ def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> No
102158
inputs, preds = metric_inputs
103159
inputs = self.processor(inputs)
104160
preds = self.processor(preds)
105-
# Extract embeddings ([CLS] token)
106-
emb_x = self.model.forward_features(inputs)
107-
emb_y = self.model.forward_features(preds)
108-
109-
# Normalize embeddings
161+
emb_x = self._get_embeddings(inputs)
162+
emb_y = self._get_embeddings(preds)
110163
emb_x = F.normalize(emb_x, dim=1)
111164
emb_y = F.normalize(emb_y, dim=1)
112165

src/pruna/evaluation/metrics/metric_img_edit_score.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""
1616
Image Edit Score metric.
1717
18-
Reference: VieScore https://github.com/ByteDance/IEA-eval
18+
VLM-based instruction-following score for image editing. Evaluates how well an edited image
19+
follows the given editing instruction on a 0-10 scale. Related work: EditScore (arXiv:2509.23909),
20+
ADIEE (ICCV 2025).
1921
"""
2022

2123
from __future__ import annotations
@@ -40,8 +42,10 @@ class ImageEditScoreMetric(StatefulMetric):
4042
"""
4143
Image Edit Score metric.
4244
43-
Evaluates how well an image was edited based on editing instructions.
44-
Higher scores indicate better editing quality.
45+
VLM-based instruction-following score for image editing. Evaluates how well an edited image
46+
follows the given editing instruction. Higher scores indicate better editing quality.
47+
48+
Related work: EditScore (arXiv:2509.23909), ADIEE (ICCV 2025).
4549
4650
Parameters
4751
----------

src/pruna/evaluation/metrics/metric_viescore.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
# limitations under the License.
1414

1515
"""
16-
VieScore metric for evaluating image quality (semantic + quality).
16+
VIEScore metric for evaluating conditional image synthesis (semantic + quality).
1717
18-
Reference: VieScore https://github.com/ByteDance/IEA-eval
18+
Reference: VIEScore: Towards Explainable Metrics for Conditional Image Synthesis Evaluation
19+
(ACL 2024) - https://arxiv.org/abs/2312.14867, https://github.com/TIGER-AI-Lab/VIEScore
1920
"""
2021

2122
from __future__ import annotations
@@ -39,7 +40,7 @@
3940
@MetricRegistry.register("viescore")
4041
class VieScoreMetric(StatefulMetric):
4142
"""
42-
VieScore metric for evaluating image quality (semantic + quality).
43+
VIEScore metric for evaluating conditional image synthesis (semantic + quality).
4344
4445
Uses VLM to assess both semantic alignment and visual quality.
4546
Higher scores indicate better overall quality.
@@ -49,6 +50,12 @@ class VieScoreMetric(StatefulMetric):
4950
- Quality score: Naturalness and artifacts
5051
- Overall: Geometric mean of semantic and quality
5152
53+
References
54+
----------
55+
VIEScore: Towards Explainable Metrics for Conditional Image Synthesis Evaluation (ACL 2024)
56+
https://arxiv.org/abs/2312.14867
57+
https://github.com/TIGER-AI-Lab/VIEScore
58+
5259
Parameters
5360
----------
5461
*args : Any

src/pruna/evaluation/metrics/metric_vqa.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
"""
1616
VQA (Visual Question Answering) metric.
1717
18-
Reference: VQAScore https://arxiv.org/abs/2310.08868
18+
Reference: VQAScore - Evaluating Text-to-Visual Generation with Image-to-Text Generation
19+
https://arxiv.org/abs/2404.01291
20+
21+
Note: VQAScore uses P(Yes) (probability of "Yes" answer) for ranking. This implementation
22+
defaults to binary (0/1) for compatibility. Set use_probability=True when using litellm
23+
with a provider that supports logprobs to get soft scores.
1924
"""
2025

2126
from __future__ import annotations
@@ -39,9 +44,12 @@ class VQAMetric(StatefulMetric):
3944
"""
4045
VQA (Visual Question Answering) metric.
4146
42-
Uses VLM to answer questions about images and compare with expected answers.
47+
Uses VLM to answer "Does this image show '{prompt}'?" and scores alignment.
4348
Higher scores indicate better image-text alignment.
4449
50+
VQAScore (arXiv:2404.01291) uses P(Yes) for ranking. Default is binary (0/1).
51+
Set use_probability=True with litellm + logprobs-capable provider for soft scores.
52+
4553
Parameters
4654
----------
4755
*args : Any
@@ -64,6 +72,9 @@ class VQAMetric(StatefulMetric):
6472
API key for litellm.
6573
call_type : str, optional
6674
Call type for the metric.
75+
use_probability : bool, optional
76+
If True, use P(Yes) when backend supports logprobs (litellm). Otherwise binary 0/1.
77+
Default is False for backward compatibility.
6778
**kwargs : Any
6879
Additional arguments.
6980
"""
@@ -86,11 +97,13 @@ def __init__(
8697
device=None,
8798
api_key: Optional[str] = None,
8899
call_type: str = SINGLE,
100+
use_probability: bool = False,
89101
**kwargs,
90102
):
91103
super().__init__(device=device)
92104
self.device = set_to_best_available_device(device)
93105
self.structured_output = structured_output
106+
self.use_probability = use_probability
94107

95108
self.vlm = get_vlm(
96109
vlm=vlm,
@@ -117,7 +130,13 @@ def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.T
117130
for i, image in enumerate(images):
118131
prompt = prompts[i] if i < len(prompts) else ""
119132
question = f'Does this image show "{prompt}"?'
120-
score = self.vlm.score([image], [question], ["Yes"], response_format=self.response_format)[0]
133+
score = self.vlm.score(
134+
[image],
135+
[question],
136+
["Yes"],
137+
response_format=self.response_format,
138+
use_probability=self.use_probability,
139+
)[0]
121140
self.scores.append(score)
122141

123142
def compute(self) -> MetricResult:

0 commit comments

Comments
 (0)