@@ -50,20 +50,18 @@ class DinoScore(StatefulMetric):
5050 reference images.
5151
5252 DINO v1 and DINOv2 load via timm. DINOv3 loads via Hugging Face Transformers (>=4.56.0).
53- See https://github.com/facebookresearch/dinov3 and
54- https://huggingface.co/collections/facebook/dinov3 for available models.
5553
5654 Parameters
5755 ----------
5856 device : str | torch.device | None
5957 The device to use for the metric.
6058 model : str
61- Backbone name. "dino" (default), "dinov2_vits14", "dinov2_vitb14",
62- "dinov2_vitl14", "dinov3_vits16", "dinov3_vits16plus", "dinov3_vitb16",
63- "dinov3_vitl16", "dinov3_vith16plus", "dinov3_vit7b16",
64- "dinov3_convnext_tiny/small/base/large", "dinov3_vitl16_sat493m",
65- "dinov3_vit7b16_sat493m", etc. DINOv3 uses HF Transformers; DINO v1/v2
66- use timm. Any timm or HF model ID also accepted.
59+ One of the registered model keys. TIMM_MODELS keys: "dino", "dinov2_vits14",
60+ "dinov2_vitb14", "dinov2_vitl14". HF_DINOV3_MODELS keys: "dinov3_vits16",
61+ "dinov3_vits16plus", "dinov3_vitb16", "dinov3_vitl16", "dinov3_vith16plus",
62+ "dinov3_vit7b16", "dinov3_convnext_tiny", "dinov3_convnext_small",
63+ "dinov3_convnext_base", "dinov3_convnext_large", "dinov3_vitl16_sat493m",
64+ "dinov3_vit7b16_sat493m".
6765 call_type : str
6866 The call type to use for the metric.
6967
@@ -96,6 +94,11 @@ class DinoScore(StatefulMetric):
9694 "dinov3_vit7b16_sat493m": "facebook/dinov3-vit7b16-pretrain-sat493m",
9795 }
9896
97+ @classmethod
98+ def valid_models(cls) -> list[str]:
99+ """Return all valid model keys."""
100+ return list(cls.TIMM_MODELS) + list(cls.HF_DINOV3_MODELS)
101+
99102 similarities: List[Tensor]
100103 metric_name: str = DINO_SCORE
101104 higher_is_better: bool = True
@@ -114,20 +117,19 @@ def __init__(
114117 pruna_logger.error(f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}")
115118 raise
116119 self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
117- self.model_name = model
120+ valid = self.valid_models()
121+ if model not in valid:
122+ raise ValueError(f"Unknown DinoScore model '{model}'. Valid keys: {valid}")
118123
119- hf_name = self.HF_DINOV3_MODELS.get(model)
120- if hf_name is not None or (model.startswith("facebook/") and "dinov3" in model):
124+ if model in self.HF_DINOV3_MODELS:
121125 from transformers import AutoModel
122126
123- hf_name = hf_name or model
124- self.model = AutoModel.from_pretrained(hf_name)
127+ self.model = AutoModel.from_pretrained(self.HF_DINOV3_MODELS[model])
125128 self.model.eval().to(self.device)
126129 self._use_transformers = True
127130 h = 224
128131 else:
129- timm_name = self.TIMM_MODELS.get(model, model)
130- self.model = timm.create_model(timm_name, pretrained=True)
132+ self.model = timm.create_model(self.TIMM_MODELS[model], pretrained=True)
131133 self.model.eval().to(self.device)
132134 self._use_transformers = False
133135 h = self.model.default_cfg.get("input_size", (3, 224, 224))[1]
@@ -145,8 +147,9 @@ def _get_embeddings(self, x: Tensor) -> Tensor:
145147 if self._use_transformers:
146148 out = self.model(pixel_values=x)
147149 return out.pooler_output
148- features = self.model.forward_features(x)
149- return features["x_norm_clstoken"] if isinstance(features, dict) else features[:, 0]
150+ else:
151+ features = self.model.forward_features(x)
152+ return features["x_norm_clstoken"] if isinstance(features, dict) else features[:, 0]
150153
151154 @torch .no_grad ()
152155 def update(self, x: List[Any] | Tensor, gt: Tensor, outputs: torch.Tensor) -> None:
0 commit comments