@@ -2,13 +2,14 @@
 
 import math
 from functools import partial
-from typing import Any, Callable, Optional, Union, cast
+from typing import Any, Callable, Literal, Optional, Union, cast
 
 import timm
 import torch
 from hydra_zen import store
 from peft import PeftConfig
 from timm.models.vision_transformer import VisionTransformer as TimmVisionTransformer
+from timm.models.vision_transformer import global_pool_nlc
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput
 
@@ -33,6 +34,9 @@ class TimmViT(nn.Module):
     ----------
     model_name : str
         The name of the model to use.
+    modality : str, default="RGB"
+        The modality of the input data. This allows the model to be used with
+        different image modalities, e.g. RGB, Depth, etc.
     projection_dim : int, default=768
         The dimension of the projection head.
     pretrained : bool, default=True
@@ -51,6 +55,7 @@ class TimmViT(nn.Module):
     def __init__(
         self,
         model_name: str,
+        modality: str = "RGB",
         projection_dim: int = 768,
         pretrained: bool = True,
         freeze_layers: Union[int, float, list[int], bool] = False,
@@ -59,6 +64,7 @@ def __init__(
         model_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__()
+        self.modality = Modalities.get_modality(modality)
         if model_kwargs is None:
             model_kwargs = {}
 
@@ -124,7 +130,7 @@ def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
         BaseModelOutput
             The output of the model.
         """
-        x = inputs[Modalities.RGB.name]
+        x = inputs[self.modality.name]
         last_hidden_state, hidden_states = self.model.forward_intermediates(
             x, output_fmt="NLC"
         )
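
For reference, a minimal usage sketch of the new contract (the timm model name and tensor shapes are illustrative; keying the batch off the instance's own `modality` attribute avoids guessing how the registry spells the key):

    import torch
    model = TimmViT("vit_base_patch16_224", modality="RGB")
    # inputs are now keyed by the configured modality instead of hard-coded RGB
    out = model({model.modality.name: torch.randn(2, 3, 224, 224)})
    assert isinstance(out, BaseModelOutput)
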
@@ -175,8 +181,11 @@ class VisionTransformer(nn.Module):
 
     Parameters
     ----------
-    img_size : Optional[list[int]], optional, default=None
-        list of input image sizes.
+    modality : str, optional, default="RGB"
+        The modality of the input data. This allows the model to be used with
+        different image modalities, e.g. RGB, Depth, etc.
+    img_size : list[int], optional, default=None
+        List of input image sizes.
     patch_size : int, optional, default=16
         Size of each patch.
     in_chans : int, optional, default=3
@@ -209,6 +218,7 @@ class VisionTransformer(nn.Module):
 
     def __init__(
         self,
+        modality: str = "RGB",
         img_size: Optional[list[int]] = None,
         patch_size: int = 16,
         in_chans: int = 3,
@@ -218,14 +228,17 @@ def __init__(
         mlp_ratio: float = 4.0,
         qkv_bias: bool = True,
         qk_scale: Optional[float] = None,
+        global_pool: Literal["", "avg", "avgmax", "max", "token"] = "",
         drop_rate: float = 0.0,
         attn_drop_rate: float = 0.0,
         drop_path_rate: float = 0.0,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
         init_std: float = 0.02,
-        **kwargs: Any,
     ) -> None:
         super().__init__()
+        assert global_pool in ("", "avg", "avgmax", "max", "token")
+
+        self.modality = Modalities.get_modality(modality)
         self.num_features = self.embed_dim = embed_dim
         self.num_heads = num_heads
         img_size = [224, 224] if img_size is None else img_size
@@ -272,6 +285,8 @@ def __init__(
         )
         self.norm = norm_layer(embed_dim)
 
+        self.global_pool = global_pool
+
         # Weight Initialization
         self.init_std = init_std
         self.apply(self._init_weights)
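
The stored `global_pool` is consumed in `forward` via timm's `global_pool_nlc` (see the hunks below). As a hedged summary of that helper's behavior on `(N, L, C)` token sequences, subject to the installed timm version:

    import torch
    from timm.models.vision_transformer import global_pool_nlc
    tokens = torch.randn(2, 197, 768)
    global_pool_nlc(tokens, pool_type="")       # no-op: (2, 197, 768)
    global_pool_nlc(tokens, pool_type="token")  # first token only: (2, 768)
    global_pool_nlc(tokens, pool_type="avg")    # mean over tokens: (2, 768)
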
@@ -301,15 +316,14 @@ def _init_weights(self, m: nn.Module) -> None:
             nn.init.constant_(m.bias, 0)
 
     def forward(
-        self,
-        x: torch.Tensor,
-        masks: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
-        return_hidden_states: bool = False,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, list[torch.Tensor]]]:
+        self, inputs: dict[str, Any], return_hidden_states: bool = False
+    ) -> tuple[torch.Tensor, Optional[list[torch.Tensor]]]:
         """Forward pass through the Vision Transformer."""
+        masks = inputs.get(self.modality.mask)
         if masks is not None and not isinstance(masks, list):
             masks = [masks]
 
+        x = inputs[self.modality.name]
         # -- Patchify x
         x = self.patch_embed(x)
 
@@ -336,10 +350,13 @@ def forward(
         if self.norm is not None:
             x = self.norm(x)
 
+        # -- Apply global pooling
+        x = global_pool_nlc(x, pool_type=self.global_pool)
+
         # -- Return both final output and hidden states if requested
         if return_hidden_states:
             return x, hidden_states
-        return x
+        return (x, None)
 
     def interpolate_pos_encoding(
         self, x: torch.Tensor, pos_embed: torch.Tensor
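
Putting the forward changes together, a hedged sketch of the new call pattern (all other constructor options left at their defaults):

    import torch
    vit = VisionTransformer(modality="RGB", global_pool="avg")
    batch = {vit.modality.name: torch.randn(2, 3, 224, 224)}
    features, hidden = vit(batch, return_hidden_states=True)
    # with global_pool="avg", features is (batch, embed_dim);
    # global_pool="" keeps the full (batch, num_patches, embed_dim) sequence
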
@@ -586,8 +603,7 @@ def _trunc_normal(
 def _no_grad_trunc_normal_(
     tensor: torch.Tensor, mean: float, std: float, a: float, b: float
 ) -> torch.Tensor:
-    """
-    Apply truncated normal initialization to a tensor.
+    """Apply truncated normal initialization to a tensor.
 
     Parameters
     ----------
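
The body of `_no_grad_trunc_normal_` is not part of this diff, but the nested `norm_cdf` helper visible in the next hunk's context suggests the standard inverse-CDF scheme. A self-contained sketch of that technique (an assumption, not the file's exact code):

    import math
    import torch

    def _trunc_normal_sketch(t: torch.Tensor, mean: float, std: float, a: float, b: float) -> torch.Tensor:
        def norm_cdf(x: float) -> float:
            return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
        with torch.no_grad():
            lo, hi = norm_cdf((a - mean) / std), norm_cdf((b - mean) / std)
            t.uniform_(2 * lo - 1, 2 * hi - 1)       # sample uniformly in CDF space
            t.erfinv_()                              # invert the (scaled) normal CDF
            t.mul_(std * math.sqrt(2.0)).add_(mean)  # rescale to target mean/std
            return t.clamp_(min=a, max=b)            # guard against fp overshoot
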
@@ -633,15 +649,23 @@ def norm_cdf(x: float) -> float:
         provider="mmlearn",
     ),
 )
-def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor:
-    """
-    Create a VisionTransformerPredictor model.
+def vit_predictor(
+    kwargs: Optional[dict[str, Any]] = None,
+) -> VisionTransformerPredictor:
+    """Create a VisionTransformerPredictor model.
+
+    Parameters
+    ----------
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the predictor.
 
     Returns
     -------
     VisionTransformerPredictor
         An instance of VisionTransformerPredictor.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformerPredictor(
         mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
     )
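
The factories now take an explicit, optionally-`None` dict rather than `**kwargs`, which is easier to express as a config field in the hydra-zen `store` registrations above. A hedged usage sketch ("depth" is an illustrative constructor option, not one confirmed by this diff):

    predictor = vit_predictor()                     # library defaults
    predictor = vit_predictor(kwargs={"depth": 6})  # extras travel as a plain dict
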
@@ -654,15 +678,25 @@ def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor:
         provider="mmlearn",
     ),
 )
-def vit_tiny(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with tiny configuration.
+def vit_tiny(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with tiny configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=192,
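
The same pattern repeats for each size variant below; the new `VisionTransformer` options introduced in this diff can be routed through the dict, e.g.:

    tiny = vit_tiny(patch_size=16, kwargs={"modality": "RGB", "global_pool": "avg"})
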
@@ -682,15 +716,25 @@ def vit_tiny(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
         provider="mmlearn",
     ),
 )
-def vit_small(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with small configuration.
+def vit_small(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with small configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=384,
@@ -710,15 +754,25 @@ def vit_small(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
         provider="mmlearn",
     ),
 )
-def vit_base(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with base configuration.
+def vit_base(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with base configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=768,
@@ -738,15 +792,25 @@ def vit_base(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
         provider="mmlearn",
     ),
 )
-def vit_large(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with large configuration.
+def vit_large(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with large configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=1024,
@@ -766,15 +830,25 @@ def vit_large(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
         provider="mmlearn",
     ),
 )
-def vit_huge(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with huge configuration.
+def vit_huge(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with huge configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=1280,
@@ -794,15 +868,25 @@ def vit_huge(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
         provider="mmlearn",
     ),
 )
-def vit_giant(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with giant configuration.
+def vit_giant(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with giant configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=1408,