Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 83 additions & 21 deletions src/pruna/evaluation/metrics/metric_dino_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor
from pruna.evaluation.metrics.utils import (
SINGLE,
get_call_type_for_single_metric,
metric_data_processor,
)
from pruna.logging.logger import pruna_logger

DINO_SCORE = "dino_score"
Expand All @@ -41,49 +45,112 @@ class DinoScore(StatefulMetric):

A similarity metric based on DINO (self-distillation with no labels),
a self-supervised vision transformer trained to learn high-level image representations without annotations.
DinoScore compares the [CLS] token embeddings of generated and reference images in this representation space,
producing a value where higher scores indicate that the generated images preserve more of the semantic content of the
reference images.

Reference
----------
https://github.com/facebookresearch/dino
https://arxiv.org/abs/2104.14294
DINO v1 and DINOv2 load via timm. DINOv3 loads via Hugging Face Transformers (>=4.56.0).

Parameters
----------
device : str | torch.device | None
The device to use for the metric.
model : str
One of the registered model keys. TIMM_MODELS keys: "dino", "dinov2_vits14",
"dinov2_vitb14", "dinov2_vitl14". HF_DINOV3_MODELS keys: "dinov3_vits16",
"dinov3_vits16plus", "dinov3_vitb16", "dinov3_vitl16", "dinov3_vith16plus",
"dinov3_vit7b16", "dinov3_convnext_tiny", "dinov3_convnext_small",
"dinov3_convnext_base", "dinov3_convnext_large", "dinov3_vitl16_sat493m",
"dinov3_vit7b16_sat493m".
call_type : str
The call type to use for the metric.

References
----------
DINO: https://github.com/facebookresearch/dino, https://arxiv.org/abs/2104.14294
DINOv2: https://github.com/facebookresearch/dinov2
DINOv3: https://github.com/facebookresearch/dinov3
"""

# Model keys resolved through timm (DINO v1 and DINOv2 ViT backbones).
# Maps a short user-facing key to the timm checkpoint identifier.
TIMM_MODELS: dict[str, str] = {
    "dino": "vit_small_patch16_224.dino",
    "dinov2_vits14": "vit_small_patch14_dinov2.lvd142m",
    "dinov2_vitb14": "vit_base_patch14_dinov2.lvd142m",
    "dinov2_vitl14": "vit_large_patch14_dinov2.lvd142m",
}

# Model keys resolved through Hugging Face Transformers (DINOv3 backbones).
# Maps a short user-facing key to the Hugging Face Hub repository id.
# NOTE(review): these repos appear to be gated on the Hub — access approval
# may be required before `from_pretrained` succeeds.
HF_DINOV3_MODELS: dict[str, str] = {
    "dinov3_vits16": "facebook/dinov3-vits16-pretrain-lvd1689m",
    "dinov3_vits16plus": "facebook/dinov3-vits16plus-pretrain-lvd1689m",
    "dinov3_vitb16": "facebook/dinov3-vitb16-pretrain-lvd1689m",
    "dinov3_vitl16": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "dinov3_vith16plus": "facebook/dinov3-vith16plus-pretrain-lvd1689m",
    "dinov3_vit7b16": "facebook/dinov3-vit7b16-pretrain-lvd1689m",
    "dinov3_convnext_tiny": "facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
    "dinov3_convnext_small": "facebook/dinov3-convnext-small-pretrain-lvd1689m",
    "dinov3_convnext_base": "facebook/dinov3-convnext-base-pretrain-lvd1689m",
    "dinov3_convnext_large": "facebook/dinov3-convnext-large-pretrain-lvd1689m",
    "dinov3_vitl16_sat493m": "facebook/dinov3-vitl16-pretrain-sat493m",
    "dinov3_vit7b16_sat493m": "facebook/dinov3-vit7b16-pretrain-sat493m",
}

@classmethod
def valid_models(cls) -> list[str]:
    """Return every registered model key, timm-backed keys first, then DINOv3 keys."""
    return [*cls.TIMM_MODELS, *cls.HF_DINOV3_MODELS]

# Per-batch cosine similarities accumulated across update() calls
# (registered as metric state in __init__ via add_state).
similarities: List[Tensor]
metric_name: str = DINO_SCORE
# Higher similarity means the generated image preserves more semantic content.
higher_is_better: bool = True
# Devices this metric supports.
runs_on: List[str] = ["cuda", "cpu"]
# Compare ground-truth images against generated outputs by default.
default_call_type: str = "gt_y"

def __init__(
    self,
    device: str | torch.device | None = None,
    model: str = "dino",
    call_type: str = SINGLE,
):
    """Initialize DinoScore with the chosen DINO backbone.

    Parameters
    ----------
    device : str | torch.device | None
        Device to run on; resolved to the best available device when None.
    model : str
        One of the registered keys: TIMM_MODELS (DINO v1 / DINOv2 via timm)
        or HF_DINOV3_MODELS (DINOv3 via Hugging Face Transformers).
    call_type : str
        The call type to use for the metric.

    Raises
    ------
    ValueError
        If the requested device is unsupported or the model key is unknown.
    """
    super().__init__(device=device)
    self.device = set_to_best_available_device(device)
    if device is not None and not any(self.device.startswith(prefix) for prefix in self.runs_on):
        pruna_logger.error(f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}")
        # BUG FIX: the original used a bare `raise` outside an except block,
        # which raises "RuntimeError: No active exception to re-raise".
        # Raise an explicit, descriptive error instead.
        raise ValueError(f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}")

    self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)

    valid = self.valid_models()
    if model not in valid:
        raise ValueError(f"Unknown DinoScore model '{model}'. Valid keys: {valid}")

    if model in self.HF_DINOV3_MODELS:
        # DINOv3 checkpoints are only published on the Hugging Face Hub;
        # import lazily so transformers is only required when actually used.
        from transformers import AutoModel

        self.model = AutoModel.from_pretrained(self.HF_DINOV3_MODELS[model])
        self._use_transformers = True
        # assumes DINOv3 preprocessing uses 224x224 inputs — TODO confirm per checkpoint
        h = 224
    else:
        self.model = timm.create_model(self.TIMM_MODELS[model], pretrained=True)
        self._use_transformers = False
        # Use the backbone's native input resolution from its timm config
        # (default_cfg["input_size"] is (C, H, W)); fall back to 224.
        h = self.model.default_cfg.get("input_size", (3, 224, 224))[1]
    self.model.eval().to(self.device)

    # State accumulating per-batch cosine similarities across update() calls.
    self.add_state("similarities", default=[])
    # Resize using the classic 256/224 resize-to-crop ratio, then center-crop
    # to the backbone's expected input size, then ImageNet-normalize.
    self.processor = transforms.Compose(
        [
            transforms.Resize(int(h * 256 / 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(h),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )

def _get_embeddings(self, x: Tensor) -> Tensor:
    """Extract one embedding vector per image from the loaded backbone.

    For Hugging Face (DINOv3) models the pooled output is used; for timm
    models the [CLS] token is taken from the feature dict (or from the
    first token of the raw feature tensor when no dict is returned).
    """
    if not self._use_transformers:
        feats = self.model.forward_features(x)
        if isinstance(feats, dict):
            return feats["x_norm_clstoken"]
        return feats[:, 0]
    return self.model(pixel_values=x).pooler_output

@torch.no_grad()
def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> None:
"""
Expand All @@ -102,15 +169,10 @@ def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> No
inputs, preds = metric_inputs
inputs = self.processor(inputs)
preds = self.processor(preds)
# Extract embeddings ([CLS] token)
emb_x = self.model.forward_features(inputs)
emb_y = self.model.forward_features(preds)

# Normalize embeddings
emb_x = self._get_embeddings(inputs)
emb_y = self._get_embeddings(preds)
emb_x = F.normalize(emb_x, dim=1)
emb_y = F.normalize(emb_y, dim=1)

# Compute cosine similarity
sim = (emb_x * emb_y).sum(dim=1)
self.similarities.append(sim)

Expand Down
57 changes: 45 additions & 12 deletions tests/evaluation/test_dino_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,55 @@
import pytest
from pruna.evaluation.metrics.metric_dino_score import DinoScore

def test_dino_score():
"""Test the DinoScore metric."""
# Use CPU for testing
metric = DinoScore(device="cpu")
# Backbone keys exercised by the parametrized test below. The heavier
# DINOv2 checkpoints are marked slow; the DINOv3 Hugging Face models are
# additionally skipped because the repos are gated and require access
# approval before they can be downloaded.
DINO_MODELS = [
    "dino",
    pytest.param("dinov2_vits14", marks=pytest.mark.slow),
    pytest.param("dinov2_vitb14", marks=pytest.mark.slow),
    pytest.param("dinov2_vitl14", marks=pytest.mark.slow),
    pytest.param(
        "dinov3_vits16",
        marks=[
            pytest.mark.slow,
            pytest.mark.skip(reason="DINOv3 HF models are gated; requires access approval"),
        ],
    ),
    pytest.param(
        "dinov3_convnext_tiny",
        marks=[
            pytest.mark.slow,
            pytest.mark.skip(reason="DINOv3 HF models are gated; requires access approval"),
        ],
    ),
]

# Create dummy images (batch of 2 images, 3x224x224)
x = torch.rand(2, 3, 224, 224)
y = torch.rand(2,3, 224, 224)

# Update metric
@pytest.mark.cpu
@pytest.mark.parametrize("model", DINO_MODELS)
def test_dino_score_models(model: str):
    """Test DinoScore with each supported backbone (dino, dinov2, dinov3)."""
    metric = DinoScore(device="cpu", model=model)
    generated = torch.rand(2, 3, 224, 224)
    reference = torch.rand(2, 3, 224, 224)
    metric.update(generated, reference, reference)

    # Compute result
    score = metric.compute()
    assert score.name == "dino_score"
    assert isinstance(score.result, float)
    assert -1.0 - 1e-5 <= score.result <= 1.0 + 1e-5


def test_dino_score_invalid_model():
    """Test that an unrecognised model key raises a clear ValueError."""
    bad_key = "facebook/dinov3-irrelevant-wrong-model"
    with pytest.raises(ValueError, match="Unknown DinoScore model"):
        DinoScore(device="cpu", model=bad_key)


def test_dino_score():
    """Test the DinoScore metric with default model (backward compatibility)."""
    metric = DinoScore(device="cpu")
    x = torch.rand(2, 3, 224, 224)
    y = torch.rand(2, 3, 224, 224)
    metric.update(x, y, y)
    result = metric.compute()
    assert result.name == "dino_score"
    assert isinstance(result.result, float)
    # FIX: diff residue left two conflicting bound checks (strict [-1, 1]
    # and tolerant). Keep a single tolerant check: cosine similarity of
    # float32-normalized embeddings can overshoot +/-1 by rounding error.
    assert -1.0 - 1e-5 <= result.result <= 1.0 + 1e-5
Loading