@@ -281,7 +281,6 @@ def forward(self, t: Tensor, seq_len=None, offset=0):
 
         return freqs
 
-
 class Rope2D:
     """Helper class to apply RoPE2D as well as interpolate on the fly."""
 
@@ -565,14 +564,14 @@ def __init__(
             use_ln_post: bool = True,
             ls_init_value: float = None,
             drop_path: float = 0.0,
-            image_size: int = 448,  # Pretrain image size only; you can pass in any image size
+            img_size: int = 448,  # Pretrain image size only; you can pass in any image size
             use_abs_posemb: bool = True,
             use_rope2d: bool = True,
             use_cls_token: bool = False,
             output_dim: Optional[int] = 1280,
             attn_pooler_heads: int = 8,
             pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
-            num_classes: int = 1000,  # no use for now
+            num_classes: int = 0,  # no use for PE
             in_chans: int = 3,
     ):
         super().__init__()
@@ -589,7 +588,9 @@ def __init__(
         self.use_abs_posemb = use_abs_posemb
         self.use_cls_token = use_cls_token
         self.use_rope2d = use_rope2d
-        self.image_size = image_size
+        if isinstance(img_size, (tuple, list)):
+            img_size = img_size[0]
+        self.img_size = img_size
 
         self.conv1 = nn.Conv2d(
             in_channels=3,
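The rename from `image_size` to `img_size` matches the argument name timm uses across its other vision models, and the new guard collapses a `(H, W)` tuple to its first element, i.e. PE still assumes square inputs. A minimal sketch of the new behavior, assuming a timm build that registers the entrypoints below and merges caller kwargs over `model_args` as usual:

    import timm

    # A tuple/list img_size is accepted but collapsed to its first element,
    # so PE still treats inputs as square.
    model = timm.create_model('vit_pe_core_base_patch16_224', img_size=(336, 336))
    assert model.img_size == 336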
@@ -652,7 +653,7 @@ def init_submodule_tensors(module):
             self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))
 
         if self.use_abs_posemb:
-            self.posemb_grid_size = self.image_size // self.patch_size
+            self.posemb_grid_size = self.img_size // self.patch_size
             self.positional_embedding = nn.Parameter(
                 init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size ** 2, self.width)
             )
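As a worked check of the sizes involved: with the 448px defaults (`img_size=448`, `patch_size=14`) the grid is 448 // 14 = 32 patches per side, so the absolute positional embedding table has 32² = 1024 rows, plus one when a class token is used:

    # Mirrors the arithmetic in __init__ above (values are the 448px defaults).
    img_size, patch_size, use_cls_token = 448, 14, False
    posemb_grid_size = img_size // patch_size                  # 32
    num_pos_rows = int(use_cls_token) + posemb_grid_size ** 2  # 1024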
@@ -731,8 +732,8 @@ def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int =
 
         return x
 
-    def forward(self, x: torch.Tensor, **kwargs):
-        x = self.forward_features(x, norm=True, **kwargs)
+    def forward(self, x: torch.Tensor, layer_idx: int = -1, strip_cls_token: bool = False):
+        x = self.forward_features(x, norm=True, layer_idx=layer_idx, strip_cls_token=strip_cls_token)
         x = self._pool(x)
 
         if self.proj_dim is not None:
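Replacing `**kwargs` with explicit `layer_idx` and `strip_cls_token` parameters makes the forward signature self-documenting and visible to type checkers, without changing behavior for callers who already passed them by keyword. A hypothetical call, with `model` built from one of the entrypoints below:

    import torch

    x = torch.randn(2, 3, 448, 448)
    # Take normalized features from the penultimate block and drop the class
    # token before pooling; both arguments were previously hidden in **kwargs.
    out = model(x, layer_idx=-2, strip_cls_token=True)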
@@ -758,8 +759,8 @@ def _cfg(url='', **kwargs):
         'num_classes': 0,
         'interpolation': 'bilinear',
         'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN,
-        'std': IMAGENET_INCEPTION_STD,
+        'mean': IMAGENET_INCEPTION_MEAN,  # (0.5, 0.5, 0.5)
+        'std': IMAGENET_INCEPTION_STD,  # (0.5, 0.5, 0.5)
         **kwargs,
     }
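The new comments are accurate: timm defines both `IMAGENET_INCEPTION_MEAN` and `IMAGENET_INCEPTION_STD` as `(0.5, 0.5, 0.5)`, so this pretrained config normalizes `[0, 1]` pixels to roughly `[-1, 1]`:

    from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

    # (x - 0.5) / 0.5 maps [0, 1] pixel values onto [-1, 1].
    assert IMAGENET_INCEPTION_MEAN == (0.5, 0.5, 0.5)
    assert IMAGENET_INCEPTION_STD == (0.5, 0.5, 0.5)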
@@ -792,7 +793,7 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
 @register_model
 def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
     model_args = dict(
-        image_size=224,
+        img_size=224,
         patch_size=16,
         width=768,
         layers=12,
808809@register_model
809810def vit_pe_core_large_patch14_336 (pretrained = False , ** kwargs ):
810811 model_args = dict (
811- image_size = 336 ,
812+ img_size = 336 ,
812813 patch_size = 14 ,
813814 width = 1024 ,
814815 layers = 24 ,
@@ -824,7 +825,7 @@ def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
824825@register_model
825826def vit_pe_core_gigantic_patch14_448 (pretrained = False , ** kwargs ):
826827 model_args = dict (
827- image_size = 448 ,
828+ img_size = 448 ,
828829 patch_size = 14 ,
829830 width = 1536 ,
830831 layers = 50 ,
@@ -840,7 +841,7 @@ def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
840841@register_model
841842def vit_pe_lang_large_patch14_448 (pretrained = False , ** kwargs ):
842843 model_args = dict (
843- image_size = 448 ,
844+ img_size = 448 ,
844845 patch_size = 14 ,
845846 width = 1024 ,
846847 layers = 23 ,
@@ -858,7 +859,7 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
858859@register_model
859860def vit_pe_lang_gigantic_patch14_448 (pretrained = False , ** kwargs ):
860861 model_args = dict (
861- image_size = 448 ,
862+ img_size = 448 ,
862863 patch_size = 14 ,
863864 width = 1536 ,
864865 layers = 47 ,
@@ -876,7 +877,7 @@ def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
876877@register_model
877878def vit_pe_spatial_gigantic_patch14_448 (pretrained = False , ** kwargs ):
878879 model_args = dict (
879- image_size = 448 ,
880+ img_size = 448 ,
880881 patch_size = 14 ,
881882 width = 1536 ,
882883 layers = 50 ,
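Every entrypoint above now takes `img_size` instead of `image_size`, so any caller passing the old keyword will need updating. To enumerate the renamed variants in a timm build that includes them:

    import timm

    # Lists the PE entrypoints registered above, e.g.
    # ['vit_pe_core_base_patch16_224', 'vit_pe_core_large_patch14_336', ...]
    print(timm.list_models('vit_pe_*'))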