Commit 414b775: torchscript for L/G models at higher resolution
Parent: a040191

1 file changed: timm/models/pe.py (36 additions, 46 deletions)

--- a/timm/models/pe.py
+++ b/timm/models/pe.py
@@ -246,10 +246,10 @@ def __init__(
     ):
         super().__init__()

-        if rope:
-            self.attn = SelfAttention(d_model, n_head, rope=rope)
-        else:
-            self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
+        #if rope:
+        self.attn = SelfAttention(d_model, n_head, rope=rope)
+        #else:
+        #    self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

         self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
         self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
@@ -281,10 +281,10 @@ def _call_attn(
         if not attn_mask.dtype == torch.bool:
             attn_mask = attn_mask.to(q_x.dtype)

-        if isinstance(self.attn, SelfAttention):
-            return self.attn(q_x, attn_mask=attn_mask)
-        else:
-            return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]
+        #if isinstance(self.attn, SelfAttention):
+        return self.attn(q_x, attn_mask=attn_mask)
+        #else:
+        #    return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]

     def forward(
             self,
@@ -380,19 +380,16 @@ def __init__(
             output_dim: Optional[int] = 1280,
             num_classes: int = 0,
             attn_pooler_heads: int = 8,
-            pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
+            use_attn_pool: bool = True,
             in_chans: int = 3,
     ):
         super().__init__()
-        assert pool_type in ("attn", "tok", "avg", "none")
-        self.pool_type = pool_type
-
         self.patch_size = patch_size
         self.heads = heads
         self.width = width
         self.layers = layers
         self.in_chans = in_chans
-
+
         # PE contains an (optional) projection layer
         # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm).
         # forward_features: x -> Transfomer(x)
@@ -418,6 +415,7 @@ def __init__(
         if isinstance(img_size, (tuple, list)):
             img_size = img_size[0]
         self.img_size = img_size
+        self.grid_size = self.img_size // self.patch_size

         self.conv1 = nn.Conv2d(
             in_channels=in_chans,
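
The new grid_size attribute precomputes the number of patch tokens per side once, so the positional-embedding sizing and the forward_intermediates reshape further down no longer derive it separately. A quick arithmetic check in Python, using the patch14 / 448px values from the configs registered at the bottom of this file (nothing beyond those numbers is assumed):

img_size, patch_size = 448, 14
grid_size = img_size // patch_size   # 32 tokens per side
num_patches = grid_size ** 2         # 1024 patch tokens per image
print(grid_size, num_patches)        # 32 1024
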
@@ -455,7 +453,7 @@ def __init__(
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)]

-        if pool_type == "attn":
+        if use_attn_pool:
             self.attn_pool = AttentionPooling(
                 embed_dim=width,
                 num_heads=attn_pooler_heads,
@@ -483,12 +481,15 @@ def init_submodule_tensors(module):

         if self.use_cls_token:
             self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))
+        else:
+            self.class_embedding = None

         if self.use_abs_posemb:
-            self.posemb_grid_size = self.img_size // self.patch_size
             self.positional_embedding = nn.Parameter(
-                init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width)
+                init_scale * torch.randn(int(self.use_cls_token) + self.grid_size**2, self.width)
             )
+        else:
+            self.positional_embedding = None

         # PE's: Transfomer(x) -> pool -> proj -> head (for timm). (PE contains an additional projection layer)
         if self.use_proj:
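
Storing class_embedding and positional_embedding as None in the else branches keeps the attributes present on every instance, so the forward paths can branch on `is not None` instead of on config flags. A minimal sketch of that pattern under torch.jit.script; TokenPrepend is an illustrative stand-in, not the timm class:

import torch
import torch.nn as nn

class TokenPrepend(nn.Module):
    # Sketch only: the attribute always exists, either as a Parameter or as
    # None, and forward gates on `is not None`, which TorchScript can refine.
    def __init__(self, dim: int, use_cls_token: bool = True):
        super().__init__()
        self.class_embedding = nn.Parameter(torch.zeros(dim)) if use_cls_token else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.class_embedding is not None:
            cls = self.class_embedding.view(1, 1, -1).expand(x.shape[0], -1, -1)
            x = torch.cat([cls, x], dim=1)
        return x

scripted = torch.jit.script(TokenPrepend(64))
print(scripted(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 17, 64])
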
@@ -498,7 +499,7 @@ def init_submodule_tensors(module):
             else:
                 self.head = nn.Identity()
         else: # no projection (eg PE-lang and PE-spatial)
-            self.proj = nn.Identity()
+            self.proj = None
             if self.num_classes > 0:
                 self.head = nn.Linear(self.width, self.num_classes) # no proj. input dim = self.width (pooled)
             else:
@@ -514,15 +515,9 @@ def set_grad_checkpointing(self, enable=True):
         self.transformer.set_grad_checkpointing(enable=enable)

     def forward_pool_and_proj(self, x: torch.Tensor):
-        if self.pool_type == "tok":
-            x = x[:, 0]
-        elif self.pool_type == "avg":
-            x = x.mean(dim=1)
-        elif self.pool_type == "attn":
+        if self.attn_pool is not None:
             x = self.attn_pool(x).squeeze(1)
-        elif self.pool_type == "none":
-            x = x
-        if self.use_proj:
+        if self.proj is not None:
             x = x @ self.proj
         return x

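
forward_pool_and_proj now branches on whether attn_pool and proj were created at construction time rather than on the removed pool_type string. A rough, self-contained sketch of the same control flow; the module name, the stand-in average pooler, and the shapes are assumptions for illustration, not the timm code:

import torch
import torch.nn as nn

class PoolProjSketch(nn.Module):
    # Sketch of the reworked flow: the scripted graph only sees attributes
    # that exist, never a pool_type string comparison.
    def __init__(self, dim: int = 64, out_dim: int = 32, use_attn_pool: bool = True, use_proj: bool = True):
        super().__init__()
        self.attn_pool = nn.AdaptiveAvgPool1d(1) if use_attn_pool else None  # stand-in pooler
        self.proj = nn.Parameter(torch.randn(dim, out_dim)) if use_proj else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, N, C)
        if self.attn_pool is not None:
            x = self.attn_pool(x.transpose(1, 2)).squeeze(-1)  # pool tokens -> (B, C)
        if self.proj is not None:
            x = x @ self.proj                                  # project -> (B, out_dim)
        return x

m = torch.jit.script(PoolProjSketch())
print(m(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 32])
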
@@ -532,30 +527,26 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
         x = self.forward_pool_and_proj(x)
         return x if pre_logits else self.head(x)

-    #def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False):
     def forward_features(self, x: torch.Tensor, norm: bool = False):
-        #: layer_idx = -1, # torchscript emits iterations over modules as unrolled loops. so dynamic layer_idx is not supported in timm as in orig pe
         batch, _, h, w = x.shape

         x = self.conv1(x)
         x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)

-        if self.use_cls_token:
+        if self.class_embedding is not None:
             x = torch.cat(
                 [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
                 dim=1,
             )
-
-        if self.use_abs_posemb:
+
+        if self.positional_embedding is not None:
             x = x + self.positional_embedding[None, ...]

         x = self.ln_pre(x)
         x = self.transformer(x)
         if norm:
             x = self.ln_post(x)

-        # if strip_cls_token and self.use_cls_token:
-        #     x = x[:, 1:, :]
         return x

     def forward(self, x: torch.Tensor):
@@ -566,7 +557,7 @@ def forward(self, x: torch.Tensor):
     def reset_classifier(self, num_classes: int):
         self.num_classes = num_classes
         if num_classes > 0:
-            if self.proj_dim > 0:
+            if self.proj is not None:
                 self.head = nn.Parameter(self.proj_dim, num_classes)
             else: # no projection (eg PE-lang and PE-spatial)
                 self.head = nn.Parameter(self.width, num_classes)
@@ -603,18 +594,17 @@ def forward_intermediates(

         # forward pass
         B, _, height, width = x.shape
-
-
+        # patch embedgging
         x = self.conv1(x)
         x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width) # NLC

-        if self.use_cls_token:
+        if self.class_embedding is not None:
             x = torch.cat(
                 [self.class_embedding.view(1, 1, -1).expand(B, -1, -1), x],
                 dim=1,
             )

-        if self.use_abs_posemb:
+        if self.positional_embedding is not None:
             x = x + self.positional_embedding[None, ...]

         x = self.ln_pre(x)
@@ -631,15 +621,15 @@ def forward_intermediates(
             intermediates.append(self.norm(x) if norm else x)

         # process intermediates
-        if self.use_cls_token:
-            prefix_tokens = [y[:, 0] for y in intermediates]
+        if self.class_embedding is not None:
+            prefix_tokens = [y[:, 0] for y in intermediates] # only one cls token in PE
             intermediates = [y[:, 1:] for y in intermediates]
         else:
             prefix_tokens = None

         if reshape:
             # reshape to BCHW output format
-            H = W = self.posemb_grid_size
+            H = W = self.grid_size
             intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
         if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
             # return_prefix not support in torchscript due to poor type handling
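
The BCHW reshape in forward_intermediates now reads the shared grid_size set in __init__. A quick shape check; the 32x32 grid matches the patch14 / 448px configs below, while the channel width of 1024 is only an illustrative value:

import torch

B, C = 2, 1024                        # illustrative batch size and width
grid_size = 448 // 14                 # 32
tokens = torch.randn(B, grid_size * grid_size, C)  # NLC, cls token already stripped
feat = tokens.reshape(B, grid_size, grid_size, C).permute(0, 3, 1, 2).contiguous()
print(feat.shape)                     # torch.Size([2, 1024, 32, 32])
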
@@ -716,7 +706,7 @@ def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
         output_dim=1024,
         num_classes=0,
         use_cls_token=True,
-        pool_type='attn',
+        use_attn_pool=True,
         use_proj=True,
     )
     return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -734,7 +724,7 @@ def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
         output_dim=1024,
         num_classes=0,
         use_cls_token=True,
-        pool_type='attn',
+        use_attn_pool=True,
         use_proj=True,
     )
     return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -752,7 +742,7 @@ def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
         output_dim=1280,
         num_classes=0,
         use_cls_token=False,
-        pool_type='attn',
+        use_attn_pool=True,
         use_proj=True,
     )
     return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -771,7 +761,7 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
         num_classes=0,
         use_cls_token=True,
         use_ln_post=False,
-        pool_type='none',
+        use_attn_pool=False,
         ls_init_value=0.1,
         use_proj=False,
     )
@@ -791,7 +781,7 @@ def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
         num_classes=0,
         use_cls_token=False,
         use_ln_post=False,
-        pool_type='none',
+        use_attn_pool=False,
         ls_init_value=0.1,
         use_proj=False,
     )
@@ -811,7 +801,7 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
         num_classes=0,
         use_cls_token=False,
         use_ln_post=False,
-        pool_type='none',
+        use_attn_pool=False,
         ls_init_value=0.1,
         use_proj=False,
     )
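
Putting the pieces together, a hedged usage sketch: build one of the L/G configs touched by this commit through timm.create_model and script it. pretrained=False keeps the example self-contained, and scripting these models is exactly what the commit targets, so treat this as a sketch rather than a guarantee for any particular timm version:

import timm
import torch

model = timm.create_model('vit_pe_core_large_patch14_336', pretrained=False)
model.eval()
scripted = torch.jit.script(model)   # the point of this commit
with torch.no_grad():
    out = scripted(torch.randn(1, 3, 336, 336))
print(out.shape)
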
