1414
1515from __future__ import annotations
1616
17- from typing import Any , List
17+ from typing import Any , List , Literal
1818
19- import timm
2019import torch
2120
2221# Ruff complains when we don't import functional as f, but common practice is to import it as F
2322import torch .nn .functional as F # noqa: N812
2423from torch import Tensor
2524from torchvision import transforms
25+ from torchvision .transforms .functional import to_pil_image
2626
2727from pruna .engine .utils import set_to_best_available_device
2828from pruna .evaluation .metrics .metric_stateful import StatefulMetric
3333
3434DINO_SCORE = "dino_score"
3535
# Classic DINO evaluation preprocessing: bicubic resize to 256, center crop to
# 224, then ImageNet mean/std normalization. There is deliberately no ToTensor
# step, so inputs are assumed to already be float image tensors — TODO confirm
# against the tensors passed into update().
DINO_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)
43+
3644
3745@MetricRegistry .register (DINO_SCORE )
3846class DinoScore (StatefulMetric ):
@@ -41,49 +49,117 @@ class DinoScore(StatefulMetric):
4149
4250 A similarity metric based on DINO (self-distillation with no labels),
4351 a self-supervised vision transformer trained to learn high-level image representations without annotations.
44- DinoScore compares the embeddings of generated and reference images in this representation space,
52+ DinoScore compares the [CLS] token embeddings of generated and reference images in this representation space,
4553 producing a value where higher scores indicate that the generated images preserve more of the semantic content of the
4654 reference images.
4755
48- Reference
56+ Supports DINO (v1), DINOv2, and DINOv3 backbones. DINOv3 uses Hugging Face Transformers
57+ (facebook/dinov3-*) with weights on Hugging Face Hub. Requires transformers>=4.56.0.
58+ DINOv3 models are gated; accept the model at huggingface.co before first use.
59+
60+ References
4961 ----------
50- https://github.com/facebookresearch/dino
51- https://arxiv.org/abs/2104.14294
62+ DINO: https://github.com/facebookresearch/dino, https://arxiv.org/abs/2104.14294
63+ DINOv2: https://github.com/facebookresearch/dinov2
64+ DINOv3: https://github.com/facebookresearch/dinov3
5265
5366 Parameters
5467 ----------
5568 device : str | torch.device | None
5669 The device to use for the metric.
70+ model : str
71+ Backbone variant. "dino" uses timm vit_small_patch16_224.dino (DINO v1).
72+ "dinov2_*" uses torch.hub facebookresearch/dinov2. "dinov3_*" uses
73+ Hugging Face facebook/dinov3-* (ViT and ConvNeXt).
5774 call_type : str
5875 The call type to use for the metric.
5976 """
6077
78+ DINOV3_HF_MODELS : dict [str , str ] = {
79+ "dinov3_vits16" : "facebook/dinov3-vits16-pretrain-lvd1689m" ,
80+ "dinov3_vits16plus" : "facebook/dinov3-vits16plus-pretrain-lvd1689m" ,
81+ "dinov3_vitb16" : "facebook/dinov3-vitb16-pretrain-lvd1689m" ,
82+ "dinov3_vitl16" : "facebook/dinov3-vitl16-pretrain-lvd1689m" ,
83+ "dinov3_vith16plus" : "facebook/dinov3-vith16plus-pretrain-lvd1689m" ,
84+ "dinov3_vit7b16" : "facebook/dinov3-vit7b16-pretrain-lvd1689m" ,
85+ "dinov3_convnext_tiny" : "facebook/dinov3-convnext-tiny-pretrain-lvd1689m" ,
86+ "dinov3_convnext_small" : "facebook/dinov3-convnext-small-pretrain-lvd1689m" ,
87+ "dinov3_convnext_base" : "facebook/dinov3-convnext-base-pretrain-lvd1689m" ,
88+ "dinov3_convnext_large" : "facebook/dinov3-convnext-large-pretrain-lvd1689m" ,
89+ "dinov3_vitl16_sat" : "facebook/dinov3-vitl16-pretrain-sat493m" ,
90+ "dinov3_vit7b16_sat" : "facebook/dinov3-vit7b16-pretrain-sat493m" ,
91+ }
92+
6193 similarities : List [Tensor ]
6294 metric_name : str = DINO_SCORE
6395 higher_is_better : bool = True
64- runs_on : List [str ] = ["cuda" , "cpu" ]
96+ runs_on : List [str ] = ["cuda" , "cpu" , "mps" ]
6597 default_call_type : str = "gt_y"
6698
67- def __init__ (self , device : str | torch .device | None = None , call_type : str = SINGLE ):
68- super ().__init__ ()
99+ def __init__ (
100+ self ,
101+ device : str | torch .device | None = None ,
102+ model : str = "dino" ,
103+ call_type : str = SINGLE ,
104+ ):
105+ super ().__init__ (device = device )
69106 self .device = set_to_best_available_device (device )
70107 if device is not None and not any (self .device .startswith (prefix ) for prefix in self .runs_on ):
71108 pruna_logger .error (f"DinoScore: device { device } not supported. Supported devices: { self .runs_on } " )
72109 raise
73110 self .call_type = get_call_type_for_single_metric (call_type , self .default_call_type )
74- # Load the DINO ViT-S/16 model once
75- self .model = timm .create_model ("vit_small_patch16_224.dino" , pretrained = True )
111+ self .model_name = model
112+ loaded = self ._load_model (model )
113+ if isinstance (loaded , tuple ):
114+ self .model , self ._hf_processor = loaded
115+ self .processor = None
116+ else :
117+ self .model = loaded
118+ self ._hf_processor = None
119+ self .processor = DINO_PREPROCESS
76120 self .model .eval ().to (self .device )
77- # Add internal state to accumulate similarities
78121 self .add_state ("similarities" , default = [])
79- self .processor = transforms .Compose (
80- [
81- transforms .Resize (256 , interpolation = transforms .InterpolationMode .BICUBIC ),
82- transforms .CenterCrop (224 ),
83- transforms .Normalize ((0.485 , 0.456 , 0.406 ), (0.229 , 0.224 , 0.225 )),
84- ]
122+
123+ def _load_model (
124+ self ,
125+ model : str ,
126+ ) -> torch .nn .Module | tuple [torch .nn .Module , object ]:
127+ if model == "dino" :
128+ import timm
129+ return timm .create_model ("vit_small_patch16_224.dino" , pretrained = True )
130+ if model .startswith ("dinov2_" ):
131+ return torch .hub .load ("facebookresearch/dinov2" , model )
132+ if model in self .DINOV3_HF_MODELS :
133+ from transformers import AutoImageProcessor , AutoModel
134+ hf_id = self .DINOV3_HF_MODELS [model ]
135+ processor = AutoImageProcessor .from_pretrained (hf_id )
136+ backbone = AutoModel .from_pretrained (hf_id )
137+ return backbone , processor
138+ raise ValueError (
139+ f"Unsupported model: { model } . "
140+ f"DINOv3 options: { list (self .DINOV3_HF_MODELS .keys ())} "
85141 )
86142
143+ def _get_embeddings (self , x : Tensor ) -> Tensor :
144+ if self .model_name == "dino" :
145+ features = self .model .forward_features (x )
146+ return features [:, 0 ]
147+ if self .model_name .startswith ("dinov2_" ):
148+ out = self .model .forward_features (x )
149+ return out ["x_norm_clstoken" ]
150+ features = self .model .forward_features (x )
151+ if isinstance (features , dict ):
152+ return features ["x_norm_clstoken" ]
153+ return features [:, 0 ]
154+
155+ def _get_embeddings_hf (self , x : Tensor ) -> Tensor :
156+ images = [to_pil_image (x [i ]) for i in range (x .shape [0 ])]
157+ inputs = self ._hf_processor (images = images , return_tensors = "pt" )
158+ pixel_values = inputs ["pixel_values" ].to (self .device )
159+ with torch .no_grad ():
160+ outputs = self .model (pixel_values )
161+ return outputs .pooler_output
162+
87163 @torch .no_grad ()
88164 def update (self , x : List [Any ] | Tensor , gt : Tensor , outputs : torch .Tensor ) -> None :
89165 """
@@ -100,13 +176,14 @@ def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> No
100176 """
101177 metric_inputs = metric_data_processor (x , gt , outputs , self .call_type )
102178 inputs , preds = metric_inputs
103- inputs = self .processor (inputs )
104- preds = self .processor (preds )
105- # Extract embeddings ([CLS] token)
106- emb_x = self .model .forward_features (inputs )
107- emb_y = self .model .forward_features (preds )
108-
109- # Normalize embeddings
179+ if self ._hf_processor is not None :
180+ emb_x = self ._get_embeddings_hf (inputs )
181+ emb_y = self ._get_embeddings_hf (preds )
182+ else :
183+ inputs = self .processor (inputs )
184+ preds = self .processor (preds )
185+ emb_x = self ._get_embeddings (inputs )
186+ emb_y = self ._get_embeddings (preds )
110187 emb_x = F .normalize (emb_x , dim = 1 )
111188 emb_y = F .normalize (emb_y , dim = 1 )
112189
0 commit comments