2828from pruna .evaluation .metrics .metric_stateful import StatefulMetric
2929from pruna .evaluation .metrics .registry import MetricRegistry
3030from pruna .evaluation .metrics .result import MetricResult
31- from pruna .evaluation .metrics .utils import SINGLE , get_call_type_for_single_metric , metric_data_processor
31+ from pruna .evaluation .metrics .utils import (
32+ SINGLE ,
33+ get_call_type_for_single_metric ,
34+ metric_data_processor ,
35+ )
3236from pruna .logging .logger import pruna_logger
3337
3438DINO_SCORE = "dino_score"
@@ -41,50 +45,113 @@ class DinoScore(StatefulMetric):
4145
4246 A similarity metric based on DINO (self-distillation with no labels),
4347 a self-supervised vision transformer trained to learn high-level image representations without annotations.
44- DinoScore compares the embeddings of generated and reference images in this representation space,
48+ DinoScore compares pooled image embeddings (the [CLS] token for ViT backbones) of generated and reference images in this representation space,
4549 producing a value where higher scores indicate that the generated images preserve more of the semantic content of the
4650 reference images.
4751
48- Reference
49- ----------
50- https://github.com/facebookresearch/dino
51- https://arxiv.org/abs/2104.14294
52+ DINO v1 and DINOv2 load via timm. DINOv3 loads via Hugging Face Transformers (>=4.56.0).
5253
5354 Parameters
5455 ----------
5556 device : str | torch.device | None
5657 The device to use for the metric.
58+ model : str
59+ One of the registered model keys. TIMM_MODELS keys: "dino", "dinov2_vits14",
60+ "dinov2_vitb14", "dinov2_vitl14". HF_DINOV3_MODELS keys: "dinov3_vits16",
61+ "dinov3_vits16plus", "dinov3_vitb16", "dinov3_vitl16", "dinov3_vith16plus",
62+ "dinov3_vit7b16", "dinov3_convnext_tiny", "dinov3_convnext_small",
63+ "dinov3_convnext_base", "dinov3_convnext_large", "dinov3_vitl16_sat493m",
64+ "dinov3_vit7b16_sat493m".
5765 call_type : str
5866 The call type to use for the metric.
67+
68+ References
69+ ----------
70+ DINO: https://github.com/facebookresearch/dino, https://arxiv.org/abs/2104.14294
71+ DINOv2: https://github.com/facebookresearch/dinov2
72+ DINOv3: https://github.com/facebookresearch/dinov3
5973 """
6074
    # Model keys accepted by ``model=`` that are loaded through timm.
    # Values are timm model identifiers (DINO v1 and DINOv2 ViT backbones).
    TIMM_MODELS: dict[str, str] = {
        "dino": "vit_small_patch16_224.dino",
        "dinov2_vits14": "vit_small_patch14_dinov2.lvd142m",
        "dinov2_vitb14": "vit_base_patch14_dinov2.lvd142m",
        "dinov2_vitl14": "vit_large_patch14_dinov2.lvd142m",
    }

    # Model keys accepted by ``model=`` that are loaded through Hugging Face
    # Transformers (DINOv3 checkpoints; the class docstring states
    # transformers>=4.56.0 is required). Values are Hub repository ids.
    # The *_sat493m entries are the satellite-imagery (SAT-493M) pretrains.
    HF_DINOV3_MODELS: dict[str, str] = {
        "dinov3_vits16": "facebook/dinov3-vits16-pretrain-lvd1689m",
        "dinov3_vits16plus": "facebook/dinov3-vits16plus-pretrain-lvd1689m",
        "dinov3_vitb16": "facebook/dinov3-vitb16-pretrain-lvd1689m",
        "dinov3_vitl16": "facebook/dinov3-vitl16-pretrain-lvd1689m",
        "dinov3_vith16plus": "facebook/dinov3-vith16plus-pretrain-lvd1689m",
        "dinov3_vit7b16": "facebook/dinov3-vit7b16-pretrain-lvd1689m",
        "dinov3_convnext_tiny": "facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
        "dinov3_convnext_small": "facebook/dinov3-convnext-small-pretrain-lvd1689m",
        "dinov3_convnext_base": "facebook/dinov3-convnext-base-pretrain-lvd1689m",
        "dinov3_convnext_large": "facebook/dinov3-convnext-large-pretrain-lvd1689m",
        "dinov3_vitl16_sat493m": "facebook/dinov3-vitl16-pretrain-sat493m",
        "dinov3_vit7b16_sat493m": "facebook/dinov3-vit7b16-pretrain-sat493m",
    }
96+
97+ @classmethod
98+ def valid_models (cls ) -> list [str ]:
99+ """Return all valid model keys."""
100+ return list (cls .TIMM_MODELS ) + list (cls .HF_DINOV3_MODELS )
101+
    # Per-batch cosine-similarity tensors accumulated by update(); registered
    # as metric state in __init__ via add_state("similarities", default=[]).
    similarities: List[Tensor]
    metric_name: str = DINO_SCORE
    # Higher cosine similarity means the generation preserves more of the
    # reference image's semantic content.
    higher_is_better: bool = True
    # Device prefixes this metric accepts (checked in __init__).
    runs_on: List[str] = ["cuda", "cpu"]
    # Default pairing when call_type is SINGLE — presumably ground truth vs.
    # generated output; verify against metric_data_processor's conventions.
    default_call_type: str = "gt_y"
66107
    def __init__(
        self,
        device: str | torch.device | None = None,
        model: str = "dino",
        call_type: str = SINGLE,
    ):
        """
        Initialize the metric: resolve the device, load the chosen backbone, and set up state.

        Parameters
        ----------
        device : str | torch.device | None
            Requested device; resolved via set_to_best_available_device. If given
            explicitly, it must start with one of the prefixes in ``runs_on``.
        model : str
            A key of TIMM_MODELS (loaded via timm) or HF_DINOV3_MODELS (loaded via
            Hugging Face Transformers) selecting the backbone.
        call_type : str
            Pairing mode, validated against the metric's ``default_call_type``.

        Raises
        ------
        ValueError
            If an explicitly requested device is unsupported, or ``model`` is not
            a registered key.
        """
        super().__init__(device=device)
        self.device = set_to_best_available_device(device)
        # Only validate devices the caller asked for explicitly; an auto-selected
        # device is trusted to already be one of runs_on.
        if device is not None and not any(self.device.startswith(prefix) for prefix in self.runs_on):
            msg = f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}"
            pruna_logger.error(msg)
            raise ValueError(msg)
        self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
        valid = self.valid_models()
        if model not in valid:
            raise ValueError(f"Unknown DinoScore model '{model}'. Valid keys: {valid}")

        if model in self.HF_DINOV3_MODELS:
            # Imported lazily so the transformers dependency is only required
            # when a DINOv3 checkpoint is actually requested.
            from transformers import AutoModel

            self.model = AutoModel.from_pretrained(self.HF_DINOV3_MODELS[model])
            self.model.eval().to(self.device)
            self._use_transformers = True
            # Assumes DINOv3 checkpoints expect 224x224 inputs — TODO confirm
            # against each checkpoint's preprocessor config.
            h = 224
        else:
            self.model = timm.create_model(self.TIMM_MODELS[model], pretrained=True)
            self.model.eval().to(self.device)
            self._use_transformers = False
            # Native input resolution from timm's model config, (C, H, W);
            # falls back to 224 when no input_size is declared.
            h = self.model.default_cfg.get("input_size", (3, 224, 224))[1]

        self.add_state("similarities", default=[])
        # Standard ImageNet normalization; the classic 256-resize / 224-crop
        # ratio is scaled to the backbone's input size h.
        self.processor = transforms.Compose(
            [
                transforms.Resize(int(h * 256 / 224), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(h),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )
87146
    def _get_embeddings(self, x: Tensor) -> Tensor:
        """Return one global embedding per image in the batch from the loaded backbone."""
        if self._use_transformers:
            # HF DINOv3 path: use the model's pooled output (the [CLS] token
            # for ViT backbones; pooled features for ConvNeXt variants —
            # TODO confirm pooler_output semantics per checkpoint).
            out = self.model(pixel_values=x)
            return out.pooler_output
        else:
            # timm path: some backbones return a feature dict exposing the
            # normalized [CLS] token, others a (B, tokens, dim) tensor whose
            # first token is [CLS] — verify this holds for every TIMM_MODELS key.
            features = self.model.forward_features(x)
            return features["x_norm_clstoken"] if isinstance(features, dict) else features[:, 0]
154+
88155 @torch .no_grad ()
89156 def update (self , x : List [Any ] | Tensor , gt : Tensor , outputs : torch .Tensor ) -> None :
90157 """
@@ -103,15 +170,10 @@ def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> No
103170 inputs , preds = metric_inputs
104171 inputs = self .processor (inputs )
105172 preds = self .processor (preds )
106- # Extract embeddings ([CLS] token)
107- emb_x = self .model .forward_features (inputs )
108- emb_y = self .model .forward_features (preds )
109-
110- # Normalize embeddings
173+ emb_x = self ._get_embeddings (inputs )
174+ emb_y = self ._get_embeddings (preds )
111175 emb_x = F .normalize (emb_x , dim = 1 )
112176 emb_y = F .normalize (emb_y , dim = 1 )
113-
114- # Compute cosine similarity
115177 sim = (emb_x * emb_y ).sum (dim = 1 )
116178 self .similarities .append (sim )
117179
0 commit comments