
Commit 2327ecc

fix config
1 parent c5c437a commit 2327ecc

File tree

1 file changed (+16, -15 lines)


timm/models/pe.py

Lines changed: 16 additions & 15 deletions
@@ -281,7 +281,6 @@ def forward(self, t: Tensor, seq_len=None, offset=0):
 
         return freqs
 
-
 class Rope2D:
     """Helper class to apply RoPE2D as well as interpolate on the fly."""
 
@@ -565,14 +564,14 @@ def __init__(
             use_ln_post: bool = True,
             ls_init_value: float = None,
             drop_path: float = 0.0,
-            image_size: int = 448,  # Pretrain image size only; you can pass in any image size
+            img_size: int = 448,  # Pretrain image size only; you can pass in any image size
             use_abs_posemb: bool = True,
             use_rope2d: bool = True,
             use_cls_token: bool = False,
             output_dim: Optional[int] = 1280,
             attn_pooler_heads: int = 8,
             pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
-            num_classes: int = 1000,  # no use for now
+            num_classes: int = 0,  # no use for PE
             in_chans: int = 3,
     ):
         super().__init__()
@@ -589,7 +588,9 @@ def __init__(
         self.use_abs_posemb = use_abs_posemb
         self.use_cls_token = use_cls_token
         self.use_rope2d = use_rope2d
-        self.image_size = image_size
+        if isinstance(img_size, (tuple, list)):
+            img_size = img_size[0]
+        self.img_size = img_size
 
         self.conv1 = nn.Conv2d(
             in_channels=3,
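
Note: the constructor change above reduces a tuple or list img_size to a single edge length before it is stored. A minimal standalone sketch of that normalization (illustrative only; the helper name is hypothetical, not part of the module):

# Sketch of the img_size handling introduced above.
def normalize_img_size(img_size):
    # A (H, W) tuple or [H, W] list collapses to its first entry; plain ints pass through.
    if isinstance(img_size, (tuple, list)):
        img_size = img_size[0]
    return img_size

assert normalize_img_size(448) == 448          # int stays as-is
assert normalize_img_size((336, 336)) == 336   # tuple -> first edge
assert normalize_img_size([224, 224]) == 224   # list -> first edge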
@@ -652,7 +653,7 @@ def init_submodule_tensors(module):
             self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))
 
         if self.use_abs_posemb:
-            self.posemb_grid_size = self.image_size // self.patch_size
+            self.posemb_grid_size = self.img_size // self.patch_size
             self.positional_embedding = nn.Parameter(
                 init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width)
             )
@@ -731,8 +732,8 @@ def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int =
 
         return x
 
-    def forward(self, x: torch.Tensor, **kwargs):
-        x = self.forward_features(x, norm=True, **kwargs)
+    def forward(self, x: torch.Tensor, layer_idx: int = -1, strip_cls_token: bool = False):
+        x = self.forward_features(x, norm=True, layer_idx=layer_idx, strip_cls_token=strip_cls_token)
         x = self._pool(x)
 
         if self.proj_dim is not None:
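
One practical effect of spelling out layer_idx and strip_cls_token instead of forwarding **kwargs: a mistyped keyword now fails at the call site rather than being silently passed through. A tiny standalone illustration (forward_stub is a stand-in with the same keyword interface, not the model's forward):

def forward_stub(x, layer_idx: int = -1, strip_cls_token: bool = False):
    return layer_idx, strip_cls_token

forward_stub("x", layer_idx=-1, strip_cls_token=False)   # accepted
try:
    forward_stub("x", layer_ix=-1)                        # typo rejected immediately
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'layer_ix'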
@@ -758,8 +759,8 @@ def _cfg(url='', **kwargs):
         'num_classes': 0,
         'interpolation': 'bilinear',
         'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN,
-        'std': IMAGENET_INCEPTION_STD,
+        'mean': IMAGENET_INCEPTION_MEAN,  # (0.5, 0.5, 0.5)
+        'std': IMAGENET_INCEPTION_STD,  # (0.5, 0.5, 0.5)
         **kwargs,
     }
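
The new inline comments document the default normalization values. If useful, they can be checked against timm.data.constants, which defines both constants as (0.5, 0.5, 0.5); a quick sanity check, assuming a timm install is available:

from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

# Confirms the values stated in the new comments.
assert IMAGENET_INCEPTION_MEAN == (0.5, 0.5, 0.5)
assert IMAGENET_INCEPTION_STD == (0.5, 0.5, 0.5)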

@@ -792,7 +793,7 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
 @register_model
 def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
     model_args = dict(
-        image_size=224,
+        img_size=224,
         patch_size=16,
         width=768,
         layers=12,
@@ -808,7 +809,7 @@ def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
 @register_model
 def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
     model_args = dict(
-        image_size=336,
+        img_size=336,
         patch_size=14,
         width=1024,
         layers=24,
@@ -824,7 +825,7 @@ def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
 @register_model
 def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
-        image_size=448,
+        img_size=448,
        patch_size=14,
         width=1536,
         layers=50,
@@ -840,7 +841,7 @@ def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
 @register_model
 def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
-        image_size=448,
+        img_size=448,
         patch_size=14,
         width=1024,
         layers=23,
@@ -858,7 +859,7 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
 @register_model
 def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
-        image_size=448,
+        img_size=448,
         patch_size=14,
         width=1536,
         layers=47,
@@ -876,7 +877,7 @@ def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
 @register_model
 def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
-        image_size=448,
+        img_size=448,
         patch_size=14,
         width=1536,
         layers=50,
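
Taken together, the registry functions above now pass img_size under the same keyword name used by timm's other ViT variants, so per-call overrides and pretrained-config handling line up with the rest of the library. A hedged usage sketch, assuming this in-progress integration builds cleanly at this commit:

import timm
import torch

# Build one of the variants registered above without pretrained weights.
model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # batch at the pretrain resolution
print(out.shape)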
