@@ -393,15 +393,21 @@ def __init__(
         self.layers = layers
         self.in_chans = in_chans

-        self.num_intermediate_features = width  # the dim before PE projection layer (vit output)
-        self.proj_dim = output_dim  # the output_dim after PE projection layer
+        # PE contains an (optional) projection layer.
+        # Flow: x -> Transformer(x) -> pool -> proj -> head (for timm).
+        # forward_features: x -> Transformer(x)
+        # forward_head: pool -> proj -> head
+        # output_dim is the final output dim of the model (kept for clarity).
         self.use_proj = use_proj
         if self.use_proj:
+            self.proj_dim = output_dim
             self.head_hidden_size = self.proj_dim
-            self.num_features = self.proj_dim
+            self.num_features = width  # self.proj_dim
         else:
-            self.head_hidden_size = self.num_intermediate_features
-            self.num_features = self.num_intermediate_features
+            self.proj_dim = 0
+            assert output_dim == width
+            self.head_hidden_size = width
+            self.num_features = width

         self.num_classes = num_classes

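A quick sketch of how the dimension bookkeeping above plays out; the concrete numbers are illustrative assumptions, not taken from this diff (PE variants differ):

# Illustrative only: mirrors the attribute logic in the hunk above; the numbers are assumptions.
width, output_dim, use_proj = 1024, 1280, True
num_features = width                                  # forward_features() output dim (pre pool/proj)
proj_dim = output_dim if use_proj else 0              # dim after the optional projection
head_hidden_size = proj_dim if use_proj else width    # dim fed to the classifier head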
@@ -446,6 +452,9 @@ def __init__(
             rope=self.rope,
         )

+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)]
+
         if pool_type == "attn":
             self.attn_pool = AttentionPooling(
                 embed_dim=width,
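The new feature_info list is what timm's feature-extraction wrappers read when a model is created with features_only=True. A hedged usage sketch, assuming the PE variants register cleanly with that path (the model name below is a placeholder, not confirmed by this diff):

import timm

# Placeholder model name; any registered PE variant would behave the same way.
model = timm.create_model('vit_pe_core_large_patch14_336', features_only=True, out_indices=(-3, -2, -1))
print(model.feature_info.channels())   # transformer width for each selected block
print(model.feature_info.reduction())  # patch_size, i.e. the spatial reduction at each block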
@@ -491,7 +500,7 @@ def init_submodule_tensors(module):
         else:  # no projection (eg PE-lang and PE-spatial)
             self.proj = nn.Identity()
             if self.num_classes > 0:
-                self.head = nn.Linear(self.width, self.num_classes)
+                self.head = nn.Linear(self.width, self.num_classes)  # no proj; input dim = self.width (pooled)
             else:
                 self.head = nn.Identity()

@@ -518,6 +527,9 @@ def forward_pool_and_proj(self, x: torch.Tensor):
         return x

     def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
+        # PE has an additional proj layer: Transformer(x) -> pool -> proj -> head (for timm).
+        # Where the pool should ideally live is still open; to discuss with Ross where to split.
+        x = self.forward_pool_and_proj(x)
         return x if pre_logits else self.head(x)

     #def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False):
@@ -544,8 +556,6 @@ def forward_features(self, x: torch.Tensor, norm: bool = False):

         # if strip_cls_token and self.use_cls_token:
         #     x = x[:, 1:, :]
-
-        x = self.forward_pool_and_proj(x)
         return x

     def forward(self, x: torch.Tensor):
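With pool and proj moved out of forward_features, the standard timm two-stage call now splits as the comments describe. A minimal sketch, assuming an instantiated PE model `model` and a correctly sized input batch `x`:

feats = model.forward_features(x)                   # (B, seq_len, width) transformer tokens, no pool/proj
embed = model.forward_head(feats, pre_logits=True)  # pooled + projected features, classifier head skipped
logits = model.forward_head(feats)                  # pooled + projected features passed through self.head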
@@ -563,6 +573,86 @@ def reset_classifier(self, num_classes: int):
         else:
             self.head = nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or a tuple of (final features, intermediates).
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(self.layers, indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+
+
+        x = self.conv1(x)
+        x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width)  # NLC
+
+        if self.use_cls_token:
+            x = torch.cat(
+                [self.class_embedding.view(1, 1, -1).expand(B, -1, -1), x],
+                dim=1,
+            )
+
+        if self.use_abs_posemb:
+            x = x + self.positional_embedding[None, ...]
+
+        x = self.ln_pre(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.transformer.resblocks
+        else:
+            blocks = self.transformer.resblocks[:max_index + 1]
+
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # normalize intermediates with final norm layer (ln_post) if enabled
+                intermediates.append(self.ln_post(x) if norm else x)
+
+        # process intermediates
+        if self.use_cls_token:
+            prefix_tokens = [y[:, 0] for y in intermediates]
+            intermediates = [y[:, 1:] for y in intermediates]
+        else:
+            prefix_tokens = None
+
+        if reshape:
+            # reshape to BCHW output format
+            H = W = self.posemb_grid_size
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
+            # return_prefix not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.ln_post(x)
+
+        return x, intermediates
+
+

 def checkpoint_filter_fn(
     state_dict: Dict[str, torch.Tensor],
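A hedged usage sketch of the new forward_intermediates API, mirroring how other timm ViTs expose it. It assumes feature_take_indices and the typing names (Optional, Union, List, Tuple) are imported elsewhere in the file, and the explicit block indices below assume a hypothetical 32-block model:

# Illustrative only; `model` is an instantiated PE vision transformer from this file.
out, intermediates = model.forward_intermediates(x, indices=4, norm=True)
print(out.shape)        # (B, seq_len, width) final tokens after ln_post
for t in intermediates:
    print(t.shape)      # (B, width, H, W) with H = W = posemb_grid_size (NCHW output_fmt)

# Intermediates only, as NLC token sequences, stopping early to skip unneeded blocks:
feats = model.forward_intermediates(
    x, indices=[7, 15, 23, 31], output_fmt='NLC', intermediates_only=True, stop_early=True)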
@@ -679,7 +769,7 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
         mlp_ratio=4.0,
         output_dim=1024,
         num_classes=0,
-        use_cls_token=False,
+        use_cls_token=True,
         use_ln_post=False,
         pool_type='none',
         ls_init_value=0.1,