55"""Copy from mmpretrain/models/backbones/vision_transformer.py."""
66from __future__ import annotations
77
8+ import math
89from functools import partial
910from typing import TYPE_CHECKING , Any , Callable , Literal
1011
4647 "vit-huge" ,
4748 "dinov2-s" ,
4849 "dinov2-small" ,
50+ "dinov2-small-seg" ,
4951 "dinov2-b" ,
5052 "dinov2-base" ,
5153 "dinov2-l" ,
@@ -87,6 +89,7 @@ class VisionTransformer(BaseModule):
         norm_layer: Normalization layer.
         act_layer: MLP activation layer.
         block_fn: Transformer block layer.
+        interpolate_offset: Work-around offset applied when interpolating positional embeddings.
         lora: Enable LoRA training.
     """

@@ -147,6 +150,17 @@ class VisionTransformer(BaseModule):
             "num_heads": 6,
             "reg_tokens": 4,
             "no_embed_class": True,
+        },
+    ),
+    **dict.fromkeys(
+        ["dinov2-small-seg"],  # segmentation
+        {
+            "patch_size": 14,
+            "embed_dim": 384,
+            "depth": 12,
+            "num_heads": 6,
+            "reg_tokens": 0,
+            "no_embed_class": False,
             "init_values": 1e-5,
         },
     ),
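Aside on the pattern above: `dict.fromkeys` maps every alias in the list to one shared settings dict, so "dinov2-s" and "dinov2-small" resolve to identical configurations. A minimal standalone sketch (values abbreviated from the hunk above):

```python
# dict.fromkeys gives every name in the list the *same* settings dict.
# The object is shared across aliases, which is safe here because the
# settings are only ever read, never mutated.
arch_zoo = {
    **dict.fromkeys(
        ["dinov2-s", "dinov2-small"],
        {"patch_size": 14, "embed_dim": 384, "reg_tokens": 4, "no_embed_class": True},
    ),
    **dict.fromkeys(
        ["dinov2-small-seg"],
        {"patch_size": 14, "embed_dim": 384, "reg_tokens": 0, "no_embed_class": False},
    ),
}
assert arch_zoo["dinov2-s"] is arch_zoo["dinov2-small"]  # same shared object
assert arch_zoo["dinov2-small-seg"]["reg_tokens"] == 0
```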
@@ -193,9 +207,9 @@ def __init__(  # noqa: PLR0913

     def __init__(  # noqa: PLR0913
         self,
-        arch: VIT_ARCH_TYPE = "vit-base",
+        arch: VIT_ARCH_TYPE | str = "vit-base",
         img_size: int | tuple[int, int] = 224,
-        patch_size: int | tuple[int, int] | None = None,
+        patch_size: int | None = None,
         in_chans: int = 3,
         num_classes: int = 1000,
         embed_dim: int | None = None,
@@ -221,6 +235,7 @@ def __init__(  # noqa: PLR0913
         mlp_layer: nn.Module | None = None,
         act_layer: LayerType | None = None,
         norm_layer: LayerType | None = None,
+        interpolate_offset: float = 0.1,
         lora: bool = False,
     ) -> None:
         super().__init__()
@@ -231,7 +246,7 @@ def __init__(  # noqa: PLR0913
         arch_settings: dict[str, Any] = self.arch_zoo[arch]

         self.img_size: int | tuple[int, int] = img_size
-        self.patch_size: int | tuple[int, int] = patch_size or arch_settings.get("patch_size", 16)
+        self.patch_size: int = patch_size or arch_settings.get("patch_size", 16)
         self.embed_dim = embed_dim or arch_settings.get("embed_dim", 768)
         depth = depth or arch_settings.get("depth", 12)
         num_heads = num_heads or arch_settings.get("num_heads", 12)
@@ -251,6 +266,7 @@ def __init__(  # noqa: PLR0913
         self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
         self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
+        self.interpolate_offset = interpolate_offset

         embed_args = {}
         if dynamic_img_size:
@@ -353,15 +369,17 @@ def resize_positional_embeddings(pos_embed: torch.Tensor, new_shape: tuple[int,
             # convert dinov2 pretrained weights
             state_dict = torch.load(checkpoint_path)
             state_dict.pop("mask_token", None)
-            state_dict["reg_token"] = state_dict.pop("register_tokens")
+            if "register_tokens" in state_dict:
+                state_dict["reg_token"] = state_dict.pop("register_tokens")
             state_dict["cls_token"] = state_dict.pop("cls_token") + state_dict["pos_embed"][:, 0]

             img_size = (self.img_size, self.img_size) if isinstance(self.img_size, int) else self.img_size
-            patch_size = (self.patch_size, self.patch_size) if isinstance(self.patch_size, int) else self.patch_size
-            state_dict["pos_embed"] = resize_positional_embeddings(
-                state_dict.pop("pos_embed")[:, 1:],
-                (img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
-            )
+            patch_size = (self.patch_size, self.patch_size)
+            if state_dict["pos_embed"].shape != self.pos_embed.shape:
+                state_dict["pos_embed"] = resize_positional_embeddings(
+                    state_dict.pop("pos_embed")[:, 1:],
+                    (img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
+                )
             self.load_state_dict(state_dict, strict=False)
         else:
             msg = f"Unsupported `checkpoint_extension` {checkpoint_ext}, please choose from 'npz' or 'pth'."
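The reworked loading path above only resizes the checkpoint's positional embeddings when the shapes actually disagree. A worked sketch of the grid arithmetic behind that guard, with illustrative sizes (518 and 224 are assumptions, not values from this diff):

```python
# Illustrative numbers only: a hypothetical 518x518 DINOv2 checkpoint loaded
# into a model built for 224x224 inputs with patch size 14.
ckpt_grid = 518 // 14    # 37 -> checkpoint pos_embed covers a 37x37 patch grid
model_grid = 224 // 14   # 16 -> model pos_embed expects a 16x16 patch grid
# The token counts differ (1 + 37*37 vs. 1 + 16*16), so the branch above
# strips the class-token slot ([:, 1:]) and resizes the rest to (16, 16).
assert (ckpt_grid, model_grid) == (37, 16)
```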
@@ -401,6 +419,137 @@ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:

         return self.pos_drop(x)

+    def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
+        """Interpolates the positional encoding to match the input dimensions.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            w (int): Width of the input image.
+            h (int): Height of the input image.
+
+        Returns:
+            torch.Tensor: Tensor with interpolated positional encoding.
+        """
+        previous_dtype = x.dtype
+        npatch = x.shape[1] - 1
+        n = self.pos_embed.shape[1] - 1
+        if npatch == n and w == h:
+            return self.pos_embed
+        pos_embed = self.pos_embed.float()
+        class_pos_embed = pos_embed[:, 0]
+        patch_pos_embed = pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_size
+        h0 = h // self.patch_size
+        m = int(math.sqrt(n))  # Recover the number of patches in each dimension
+        if m * m != n:
+            msg = f"Expected m * m to equal n, but got m={m}, n={n}"
+            raise ValueError(msg)
+        kwargs = {}
+        if self.interpolate_offset:
+            # fix float error by introducing small offset
+            sx = float(w0 + self.interpolate_offset) / m
+            sy = float(h0 + self.interpolate_offset) / m
+            kwargs["scale_factor"] = (sx, sy)
+        else:
+            # Simply specify an output size instead of a scale factor
+            kwargs["size"] = (w0, h0)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2),
+            mode="bicubic",
+            **kwargs,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
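Why the `interpolate_offset` workaround: with a `scale_factor`, `F.interpolate` derives the output size as `floor(input_size * scale_factor)`, so an exact ratio can round down one pixel under floating-point error. A worked example with assumed grid sizes:

```python
import math

m, w0, offset = 16, 37, 0.1     # assumed: 16x16 pretrained grid -> 37x37 target
sx = float(w0 + offset) / m     # 2.31875 rather than exactly 37/16 = 2.3125
out = math.floor(m * sx)        # floor(~37.1) = 37, the intended width
assert out == w0
# With interpolate_offset=0, the code passes size=(w0, h0) instead, which
# sidesteps the rounding question entirely at a slightly different sampling grid.
```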
+    def prepare_tokens_with_masks(self, x: torch.Tensor, masks: torch.Tensor | None = None) -> torch.Tensor:
+        """Prepare tokens with optional masks.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            masks (torch.Tensor | None): Optional masks tensor.
+
+        Returns:
+            torch.Tensor: Tensor with prepared tokens.
+        """
+        _, _, w, h = x.shape
+        x = self.patch_embed(x)
+        if masks is not None:
+            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        if self.reg_token is not None:
+            x = torch.cat(
+                (
+                    x[:, :1],
+                    self.reg_token.expand(x.shape[0], -1, -1),
+                    x[:, 1:],
+                ),
+                dim=1,
+            )
+
+        return x
+
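The method assembles the sequence as `[CLS | register tokens | patches]`, adding the positional encoding before the register tokens are spliced in. A usage sketch with assumed dinov2-small dimensions (`model` is a hypothetical instance, not constructed in this diff):

```python
import torch

# Assumed: model = VisionTransformer(arch="dinov2-small"), i.e. patch size 14,
# embed dim 384, 4 register tokens.
x = torch.randn(2, 3, 224, 224)              # (batch, channels, H, W)
tokens = model.prepare_tokens_with_masks(x)  # no masking
# 224 // 14 = 16, so the layout is [1 CLS | 4 reg | 16 * 16 patches]:
assert tokens.shape == (2, 1 + 4 + 16 * 16, 384)
```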
+    def _get_intermediate_layers_not_chunked(self, x: torch.Tensor, n: int | list[int] = 1) -> list[torch.Tensor]:
+        """Get intermediate layers without chunking.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            n (int | list[int]): Number of last blocks to take. If it's a list, take the specified blocks.
+
+        Returns:
+            list[torch.Tensor]: List of intermediate layer outputs.
+        """
+        x = self.prepare_tokens_with_masks(x)
+        # If n is an int, take the n last blocks. If it's a list, take them
+        output, total_block_len = [], len(self.blocks)
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in blocks_to_take:
+                output.append(x)
+        if len(output) != len(blocks_to_take):
+            msg = f"only {len(output)} / {len(blocks_to_take)} blocks found"
+            raise RuntimeError(msg)
+        return output
+
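Note that `n` may be either an int ("take the last n blocks") or an explicit list of block indices; the `isinstance` check above distinguishes the two. A standalone sketch of that selection logic:

```python
# Block selection for a 12-block model (dinov2-small depth).
total_block_len = 12

n = 4  # int: take the last n blocks
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
assert list(blocks_to_take) == [8, 9, 10, 11]

n = [2, 5, 8, 11]  # list: explicit indices, e.g. for FPN-style taps
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
assert list(blocks_to_take) == [2, 5, 8, 11]
```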
+    def get_intermediate_layers(
+        self,
+        x: torch.Tensor,
+        n: int | list[int] = 1,  # Layers or n last layers to take
+        reshape: bool = False,
+        return_class_token: bool = False,
+        norm: bool = True,
+    ) -> tuple:
+        """Get intermediate layers of the VisionTransformer.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            n (int | list[int]): Number of last blocks to take. If it's a list, take the specified blocks.
+            reshape (bool): Whether to reshape the output feature maps.
+            return_class_token (bool): Whether to return the class token.
+            norm (bool): Whether to apply normalization to the outputs.
+
+        Returns:
+            tuple: A tuple containing the intermediate layer outputs.
+        """
+        outputs = self._get_intermediate_layers_not_chunked(x, n)
+        if norm:
+            outputs = [self.norm(out) for out in outputs]
+        class_tokens = [out[:, 0] for out in outputs]
+        outputs = [out[:, 1 + self.num_reg_tokens :] for out in outputs]
+        if reshape:
+            b, _, w, h = x.shape
+            outputs = [
+                out.reshape(b, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+                for out in outputs
+            ]
+        if return_class_token:
+            return tuple(zip(outputs, class_tokens))
+        return tuple(outputs)
+
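A hedged usage sketch for the segmentation path (`model` and its checkpoint are assumed, not part of this diff): with `reshape=True` the patch tokens come back as NCHW feature maps sized by the patch grid, ready for a dense decoder.

```python
import torch

# Assumed: model = VisionTransformer(arch="dinov2-small-seg"), 224x224 input,
# patch size 14, embed dim 384, no register tokens.
x = torch.randn(1, 3, 224, 224)
feats = model.get_intermediate_layers(x, n=4, reshape=True, norm=True)
for f in feats:
    # (batch, embed_dim, H // patch_size, W // patch_size) = (1, 384, 16, 16)
    assert f.shape == (1, 384, 16, 16)
```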
     def forward(
         self,
         x: torch.Tensor,