1414
1515from __future__ import annotations
1616
17- from typing import Any , List
17+ from typing import Any , List , Literal
1818
19- import timm
2019import torch
2120
2221# Ruff complains when we don't import functional as f, but common practice is to import it as F
3332
3433DINO_SCORE = "dino_score"
3534
# Standard ImageNet evaluation preprocessing shared by every DINO backbone:
# bicubic resize to 256px, 224px center crop, ImageNet mean/std normalization.
# NOTE(review): there is deliberately no ToTensor step — this pipeline is applied
# to tensors already produced upstream (confirm callers always pass tensors).
DINO_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)
3643
3744@MetricRegistry .register (DINO_SCORE )
3845class DinoScore (StatefulMetric ):
@@ -41,48 +48,97 @@ class DinoScore(StatefulMetric):
4148
4249 A similarity metric based on DINO (self-distillation with no labels),
4350 a self-supervised vision transformer trained to learn high-level image representations without annotations.
44- DinoScore compares the embeddings of generated and reference images in this representation space,
51+ DinoScore compares the [CLS] token embeddings of generated and reference images in this representation space,
4552 producing a value where higher scores indicate that the generated images preserve more of the semantic content of the
4653 reference images.
4754
48- Reference
55+ Supports DINO (v1), DINOv2, and DINOv3 backbones. DINOv3 models may require weights from Meta's download form.
56+
57+ References
4958 ----------
50- https://github.com/facebookresearch/dino
51- https://arxiv.org/abs/2104.14294
59+ DINO: https://github.com/facebookresearch/dino, https://arxiv.org/abs/2104.14294
60+ DINOv2: https://github.com/facebookresearch/dinov2
61+ DINOv3: https://github.com/facebookresearch/dinov3
5262
5363 Parameters
5464 ----------
5565 device : str | torch.device | None
5666 The device to use for the metric.
67+ model : {"dino", "dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov3_vits16", "dinov3_vitb16", "dinov3_vitl16"}
68+ Backbone variant. "dino" uses timm vit_small_patch16_224.dino (DINO v1).
69+ "dinov2_*" uses torch.hub facebookresearch/dinov2. "dinov3_*" uses timm (requires timm>=1.0.20).
5770 call_type : str
5871 The call type to use for the metric.
5972 """
6073
6174 similarities : List [Tensor ]
6275 metric_name : str = DINO_SCORE
6376 higher_is_better : bool = True
64- runs_on : List [str ] = ["cuda" , "cpu" ]
77+ runs_on : List [str ] = ["cuda" , "cpu" , "mps" ]
6578 default_call_type : str = "gt_y"
6679
67- def __init__ (self , device : str | torch .device | None = None , call_type : str = SINGLE ):
68- super ().__init__ ()
80+ def __init__ (
81+ self ,
82+ device : str | torch .device | None = None ,
83+ model : Literal [
84+ "dino" , "dinov2_vits14" , "dinov2_vitb14" , "dinov2_vitl14" , "dinov3_vits16" , "dinov3_vitb16" , "dinov3_vitl16"
85+ ] = "dino" ,
86+ call_type : str = SINGLE ,
87+ ):
88+ super ().__init__ (device = device )
6989 self .device = set_to_best_available_device (device )
7090 if device is not None and not any (self .device .startswith (prefix ) for prefix in self .runs_on ):
7191 pruna_logger .error (f"DinoScore: device { device } not supported. Supported devices: { self .runs_on } " )
7292 raise
7393 self .call_type = get_call_type_for_single_metric (call_type , self .default_call_type )
74- # Load the DINO ViT-S/16 model once
75- self .model = timm . create_model ( "vit_small_patch16_224.dino" , pretrained = True )
94+ self . model_name = model
95+ self .model = self . _load_model ( model )
7696 self .model .eval ().to (self .device )
77- # Add internal state to accumulate similarities
7897 self .add_state ("similarities" , default = [])
79- self .processor = transforms .Compose (
80- [
81- transforms .Resize (256 , interpolation = transforms .InterpolationMode .BICUBIC ),
82- transforms .CenterCrop (224 ),
83- transforms .Normalize ((0.485 , 0.456 , 0.406 ), (0.229 , 0.224 , 0.225 )),
84- ]
85- )
98+ self .processor = DINO_PREPROCESS
99+
100+ def _load_model (
101+ self ,
102+ model : str ,
103+ ) -> torch .nn .Module :
104+ if model == "dino" :
105+ import timm
106+ return timm .create_model ("vit_small_patch16_224.dino" , pretrained = True )
107+ if model .startswith ("dinov2_" ):
108+ return torch .hub .load ("facebookresearch/dinov2" , model )
109+ if model .startswith ("dinov3_" ):
110+ import timm
111+ timm_map = {
112+ "dinov3_vits16" : "vit_small_patch16_dinov3.lvd1689m" ,
113+ "dinov3_vitb16" : "vit_base_patch16_dinov3.lvd1689m" ,
114+ "dinov3_vitl16" : "vit_large_patch16_dinov3.lvd1689m" ,
115+ }
116+ timm_name = timm_map .get (model )
117+ if timm_name is None :
118+ raise ValueError (f"Unsupported DINOv3 model: { model } . Choose from { list (timm_map .keys ())} " )
119+ try :
120+ return timm .create_model (timm_name , pretrained = True )
121+ except Exception as e :
122+ raise ValueError (
123+ f"DINOv3 requires timm>=1.0.20 and model weights from Meta. "
124+ f"See https://github.com/facebookresearch/dinov3. Error: { e } "
125+ ) from e
126+ raise ValueError (f"Unsupported model: { model } " )
127+
128+ def _get_embeddings (self , x : Tensor ) -> Tensor :
129+ if self .model_name == "dino" :
130+ features = self .model .forward_features (x )
131+ return features [:, 0 ]
132+ if self .model_name .startswith ("dinov2_" ):
133+ out = self .model .forward_features (x )
134+ return out ["x_norm_clstoken" ]
135+ if self .model_name .startswith ("dinov3_" ):
136+ features = self .model .forward_features (x )
137+ return features [:, 0 ]
138+ features = self .model .forward_features (x )
139+ if isinstance (features , dict ):
140+ return features ["x_norm_clstoken" ]
141+ return features [:, 0 ]
86142
87143 @torch .no_grad ()
88144 def update (self , x : List [Any ] | Tensor , gt : Tensor , outputs : torch .Tensor ) -> None :
@@ -102,11 +158,8 @@ def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> No
102158 inputs , preds = metric_inputs
103159 inputs = self .processor (inputs )
104160 preds = self .processor (preds )
105- # Extract embeddings ([CLS] token)
106- emb_x = self .model .forward_features (inputs )
107- emb_y = self .model .forward_features (preds )
108-
109- # Normalize embeddings
161+ emb_x = self ._get_embeddings (inputs )
162+ emb_y = self ._get_embeddings (preds )
110163 emb_x = F .normalize (emb_x , dim = 1 )
111164 emb_y = F .normalize (emb_y , dim = 1 )
112165
0 commit comments