@@ -246,10 +246,10 @@ def __init__(
     ):
         super().__init__()

-        if rope:
-            self.attn = SelfAttention(d_model, n_head, rope=rope)
-        else:
-            self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
+        # if rope:
+        self.attn = SelfAttention(d_model, n_head, rope=rope)
+        # else:
+        # self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

         self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
         self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
@@ -281,10 +281,10 @@ def _call_attn(
             if not attn_mask.dtype == torch.bool:
                 attn_mask = attn_mask.to(q_x.dtype)

-        if isinstance(self.attn, SelfAttention):
-            return self.attn(q_x, attn_mask=attn_mask)
-        else:
-            return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]
+        # if isinstance(self.attn, SelfAttention):
+        return self.attn(q_x, attn_mask=attn_mask)
+        # else:
+        # return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]

     def forward(
         self,
@@ -380,19 +380,16 @@ def __init__(
         output_dim: Optional[int] = 1280,
         num_classes: int = 0,
         attn_pooler_heads: int = 8,
-        pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
+        use_attn_pool: bool = True,
         in_chans: int = 3,
     ):
         super().__init__()
-        assert pool_type in ("attn", "tok", "avg", "none")
-        self.pool_type = pool_type
-
         self.patch_size = patch_size
         self.heads = heads
         self.width = width
         self.layers = layers
         self.in_chans = in_chans
-
+
         # PE contains an (optional) projection layer
         # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm).
         # forward_features: x -> Transfomer(x)
@@ -418,6 +415,7 @@ def __init__(
         if isinstance(img_size, (tuple, list)):
             img_size = img_size[0]
         self.img_size = img_size
+        self.grid_size = self.img_size // self.patch_size

         self.conv1 = nn.Conv2d(
             in_channels=in_chans,
@@ -455,7 +453,7 @@ def __init__(
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)]

-        if pool_type == "attn":
+        if use_attn_pool:
             self.attn_pool = AttentionPooling(
                 embed_dim=width,
                 num_heads=attn_pooler_heads,
@@ -483,12 +481,15 @@ def init_submodule_tensors(module):

         if self.use_cls_token:
             self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))
+        else:
+            self.class_embedding = None

         if self.use_abs_posemb:
-            self.posemb_grid_size = self.img_size // self.patch_size
             self.positional_embedding = nn.Parameter(
-                init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size ** 2, self.width)
+                init_scale * torch.randn(int(self.use_cls_token) + self.grid_size ** 2, self.width)
             )
+        else:
+            self.positional_embedding = None

         # PE's: Transfomer(x) -> pool -> proj -> head (for timm). (PE contains an additional projection layer)
         if self.use_proj:
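A quick worked check of the positional-embedding shape set up in the hunk above, assuming the base 224px / patch-16 configuration with a cls token registered later in this diff (the exact width is left symbolic):

```python
# Assumed config: img_size=224, patch_size=16, use_cls_token=True (vit_pe_core_base_patch16_224)
grid_size = 224 // 16                    # 14, matching the self.grid_size attribute added above
num_pos = int(True) + grid_size ** 2     # 1 cls-token position + 196 patch positions = 197
# positional_embedding is therefore expected to have shape (197, width)
print(num_pos)                           # 197
```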
@@ -498,7 +499,7 @@ def init_submodule_tensors(module):
             else:
                 self.head = nn.Identity()
         else:  # no projection (eg PE-lang and PE-spatial)
-            self.proj = nn.Identity()
+            self.proj = None
             if self.num_classes > 0:
                 self.head = nn.Linear(self.width, self.num_classes)  # no proj. input dim = self.width (pooled)
             else:
@@ -514,15 +515,9 @@ def set_grad_checkpointing(self, enable=True):
         self.transformer.set_grad_checkpointing(enable=enable)

     def forward_pool_and_proj(self, x: torch.Tensor):
-        if self.pool_type == "tok":
-            x = x[:, 0]
-        elif self.pool_type == "avg":
-            x = x.mean(dim=1)
-        elif self.pool_type == "attn":
+        if self.attn_pool is not None:
             x = self.attn_pool(x).squeeze(1)
-        elif self.pool_type == "none":
-            x = x
-        if self.use_proj:
+        if self.proj is not None:
             x = x @ self.proj
         return x

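For readers comparing against the removed `pool_type` branches, here is a minimal standalone sketch of the simplified flow above. Names mirror the module attributes, but this is not the actual method; it assumes `attn_pool` and `proj` are `None` when disabled:

```python
from typing import Callable, Optional
import torch

def pool_and_proj(
        x: torch.Tensor,
        attn_pool: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        proj: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if attn_pool is not None:        # replaces the old pool_type == "attn" branch
        x = attn_pool(x).squeeze(1)  # (B, L, C) -> (B, C)
    if proj is not None:             # replaces the old use_proj check
        x = x @ proj                 # (B, C) @ (C, D) -> (B, D)
    return x                         # both None: tokens pass through unpooled and unprojected
```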
@@ -532,30 +527,26 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
         x = self.forward_pool_and_proj(x)
         return x if pre_logits else self.head(x)

-    #def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False):
     def forward_features(self, x: torch.Tensor, norm: bool = False):
-        #: layer_idx = -1, # torchscript emits iterations over modules as unrolled loops. so dynamic layer_idx is not supported in timm as in orig pe
         batch, _, h, w = x.shape

         x = self.conv1(x)
         x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)

-        if self.use_cls_token:
+        if self.class_embedding is not None:
             x = torch.cat(
                 [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
                 dim=1,
             )
-
-        if self.use_abs_posemb:
+
+        if self.positional_embedding is not None:
             x = x + self.positional_embedding[None, ...]

         x = self.ln_pre(x)
         x = self.transformer(x)
         if norm:
             x = self.ln_post(x)

-        # if strip_cls_token and self.use_cls_token:
-        # x = x[:, 1:, :]
         return x

     def forward(self, x: torch.Tensor):
@@ -566,7 +557,7 @@ def forward(self, x: torch.Tensor):
     def reset_classifier(self, num_classes: int):
         self.num_classes = num_classes
         if num_classes > 0:
-            if self.proj_dim > 0:
+            if self.proj is not None:
                 self.head = nn.Parameter(self.proj_dim, num_classes)
             else:  # no projection (eg PE-lang and PE-spatial)
                 self.head = nn.Parameter(self.width, num_classes)
@@ -603,18 +594,17 @@ def forward_intermediates(

         # forward pass
         B, _, height, width = x.shape
-
-
+        # patch embedgging
         x = self.conv1(x)
         x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width)  # NLC

-        if self.use_cls_token:
+        if self.class_embedding is not None:
             x = torch.cat(
                 [self.class_embedding.view(1, 1, -1).expand(B, -1, -1), x],
                 dim=1,
             )

-        if self.use_abs_posemb:
+        if self.positional_embedding is not None:
             x = x + self.positional_embedding[None, ...]

         x = self.ln_pre(x)
@@ -631,15 +621,15 @@ def forward_intermediates(
                 intermediates.append(self.norm(x) if norm else x)

         # process intermediates
-        if self.use_cls_token:
-            prefix_tokens = [y[:, 0] for y in intermediates]
+        if self.class_embedding is not None:
+            prefix_tokens = [y[:, 0] for y in intermediates]  # only one cls token in PE
             intermediates = [y[:, 1:] for y in intermediates]
         else:
             prefix_tokens = None

         if reshape:
             # reshape to BCHW output format
-            H = W = self.posemb_grid_size
+            H = W = self.grid_size
             intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
         if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
             # return_prefix not support in torchscript due to poor type handling
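A small worked example of the BCHW reshape with the renamed `self.grid_size`, assuming the 448px / patch-14 configuration of the gigantic variants below (batch size and width are illustrative):

```python
import torch

# Assumed config: img_size=448, patch_size=14 (as in the *_patch14_448 variants registered below)
grid_size = 448 // 14                 # 32
B, C = 2, 1536                        # example batch size / width, purely illustrative
tokens = torch.randn(B, grid_size ** 2, C)  # patch tokens, with any cls token already stripped
bchw = tokens.reshape(B, grid_size, grid_size, C).permute(0, 3, 1, 2).contiguous()
print(bchw.shape)                     # torch.Size([2, 1536, 32, 32])
```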
@@ -716,7 +706,7 @@ def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
         output_dim=1024,
         num_classes=0,
         use_cls_token=True,
-        pool_type='attn',
+        use_attn_pool=True,
         use_proj=True,
     )
     return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
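A hedged usage sketch of the renamed flag. Because kwargs are merged via `**dict(model_args, **kwargs)`, `use_attn_pool` can be overridden at creation time just as `pool_type` could be before. This assumes the model name is registered with timm as in this file; the shape comments are assumptions for this config:

```python
import timm
import torch

# Default config: attention pooling + projection (what pool_type='attn' meant in the old API)
model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)

# Override the new flag the same way pool_type used to be overridden;
# with use_attn_pool=False no AttentionPooling module is created.
tokens_only = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False, use_attn_pool=False)

x = torch.randn(1, 3, 224, 224)
feats = model.forward_features(x)   # unpooled tokens, assumed (1, 1 + 14 * 14, width)
pooled = model.forward_head(feats)  # attention-pooled and projected output
```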
@@ -734,7 +724,7 @@ def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
         output_dim=1024,
         num_classes=0,
         use_cls_token=True,
-        pool_type='attn',
+        use_attn_pool=True,
         use_proj=True,
     )
     return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -752,7 +742,7 @@ def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
         output_dim=1280,
         num_classes=0,
         use_cls_token=False,
-        pool_type='attn',
+        use_attn_pool=True,
         use_proj=True,
     )
     return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -771,7 +761,7 @@ def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
         num_classes=0,
         use_cls_token=True,
         use_ln_post=False,
-        pool_type='none',
+        use_attn_pool=False,
         ls_init_value=0.1,
         use_proj=False,
     )
@@ -791,7 +781,7 @@ def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
         num_classes=0,
         use_cls_token=False,
         use_ln_post=False,
-        pool_type='none',
+        use_attn_pool=False,
         ls_init_value=0.1,
         use_proj=False,
     )
@@ -811,7 +801,7 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
         num_classes=0,
         use_cls_token=False,
         use_ln_post=False,
-        pool_type='none',
+        use_attn_pool=False,
         ls_init_value=0.1,
         use_proj=False,
     )
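Finally, a hedged migration sketch for configs that still pass the removed `pool_type` argument. Based only on what this diff shows, `"tok"` and `"avg"` pooling have no direct equivalent under the new flag; the helper below is hypothetical and not part of this commit:

```python
# Hypothetical helper mapping old pool_type values onto the new use_attn_pool flag.
def legacy_pool_type_to_kwargs(pool_type: str) -> dict:
    if pool_type == "attn":
        return dict(use_attn_pool=True)
    if pool_type == "none":
        return dict(use_attn_pool=False)
    raise ValueError(f"pool_type={pool_type!r} is not expressible with use_attn_pool alone")

print(legacy_pool_type_to_kwargs("attn"))  # {'use_attn_pool': True}
```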