
Commit a040191

add forward_intermediates support
1 parent 9dbb47d

2 files changed: +102, -11 lines


tests/test_models.py

Lines changed: 3 additions & 2 deletions
@@ -53,15 +53,15 @@
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'pe'
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*', 'pe_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)

@@ -224,6 +224,7 @@ def test_model_backward(model_name, batch_size):
     timm.models.MobileNetV3,
     timm.models.RepGhostNet,
     timm.models.VGG,
+    timm.models.pe,
 )
 
 @pytest.mark.cfg

timm/models/pe.py

Lines changed: 99 additions & 9 deletions
@@ -393,15 +393,21 @@ def __init__(
         self.layers = layers
         self.in_chans = in_chans
 
-        self.num_intermediate_features = width  # the dim before PE projection layer (vit output)
-        self.proj_dim = output_dim  # the output_dim after PE projection layer
+        # PE contains an (optional) projection layer.
+        # Flow: x -> Transformer(x) -> pool -> proj -> head (for timm).
+        # forward_features: x -> Transformer(x)
+        # forward_head: pool -> proj -> head
+        # output_dim is the final output dim of the model (kept for clarity)
         self.use_proj = use_proj
         if self.use_proj:
+            self.proj_dim = output_dim
             self.head_hidden_size = self.proj_dim
-            self.num_features = self.proj_dim
+            self.num_features = width  # self.proj_dim
         else:
-            self.head_hidden_size = self.num_intermediate_features
-            self.num_features = self.num_intermediate_features
+            self.proj_dim = 0
+            assert output_dim == width
+            self.head_hidden_size = width
+            self.num_features = width
 
         self.num_classes = num_classes
 
@@ -446,6 +452,9 @@ def __init__(
             rope=self.rope,
         )
 
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)]
+
         if pool_type == "attn":
             self.attn_pool = AttentionPooling(
                 embed_dim=width,
@@ -491,7 +500,7 @@ def init_submodule_tensors(module):
         else:  # no projection (eg PE-lang and PE-spatial)
             self.proj = nn.Identity()
             if self.num_classes > 0:
-                self.head = nn.Linear(self.width, self.num_classes)
+                self.head = nn.Linear(self.width, self.num_classes)  # no proj, input dim = self.width (pooled)
             else:
                 self.head = nn.Identity()
 

@@ -518,6 +527,9 @@ def forward_pool_and_proj(self, x: torch.Tensor):
         return x
 
     def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
+        # PE has an additional proj layer: Transformer(x) -> pool -> proj -> head (for timm).
+        # TODO: discuss with Ross where the pool/proj split should ideally live.
+        x = self.forward_pool_and_proj(x)
         return x if pre_logits else self.head(x)
 
     #def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False):
@@ -544,8 +556,6 @@ def forward_features(self, x: torch.Tensor, norm: bool = False):
 
         # if strip_cls_token and self.use_cls_token:
         #     x = x[:, 1:, :]
-
-        x = self.forward_pool_and_proj(x)
         return x
 
     def forward(self, x: torch.Tensor):
@@ -563,6 +573,86 @@ def reset_classifier(self, num_classes: int):
         else:
             self.head = nn.Identity()
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(self.layers, indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+
+        x = self.conv1(x)
+        x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width)  # NLC
+
+        if self.use_cls_token:
+            x = torch.cat(
+                [self.class_embedding.view(1, 1, -1).expand(B, -1, -1), x],
+                dim=1,
+            )
+
+        if self.use_abs_posemb:
+            x = x + self.positional_embedding[None, ...]
+
+        x = self.ln_pre(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.transformer.resblocks
+        else:
+            blocks = self.transformer.resblocks[:max_index + 1]
+
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if self.use_cls_token:
+            prefix_tokens = [y[:, 0] for y in intermediates]
+            intermediates = [y[:, 1:] for y in intermediates]
+        else:
+            prefix_tokens = None
+
+        if reshape:
+            # reshape to BCHW output format
+            H = W = self.posemb_grid_size
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.ln_post(x)
+
+        return x, intermediates
+
+
 
 def checkpoint_filter_fn(
         state_dict: Dict[str, torch.Tensor],
@@ -679,7 +769,7 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
         mlp_ratio=4.0,
         output_dim=1024,
         num_classes=0,
-        use_cls_token=False,
+        use_cls_token=True,
         use_ln_post=False,
         pool_type='none',
         ls_init_value=0.1,

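For reference, a minimal usage sketch (not part of the commit) of the new forward_intermediates API on one of the PE variants defined in this file. The model name comes from the diff above; pretrained weights are not assumed, and the features_only note at the end is an untested assumption.

import torch
import timm

# Create a PE model defined in timm/models/pe.py; random weights keep this runnable offline.
model = timm.create_model('vit_pe_lang_large_patch14_448', pretrained=False)
model.eval()

x = torch.randn(1, 3, 448, 448)

# Collect intermediates from the last two transformer blocks only, reshaped to NCHW.
with torch.no_grad():
    feats = model.forward_intermediates(
        x,
        indices=2,
        output_fmt='NCHW',
        intermediates_only=True,
    )

for f in feats:
    # Each feature map is (B, width, H, W) with H = W = 448 // 14 = 32.
    print(f.shape)

# With feature_info and forward_intermediates in place, feature-extraction wrapping via
# timm.create_model(..., features_only=True) may also work, but that is an assumption here.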