@@ -386,6 +386,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self._forward(x)


+def global_pool_nlc(
+        x: torch.Tensor,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+        reduce_include_prefix: bool = False,
+):
+    if not pool_type:
+        return x
+
+    if pool_type == 'token':
+        x = x[:, 0]  # class token
+    else:
+        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
+        if pool_type == 'avg':
+            x = x.mean(dim=1)
+        elif pool_type == 'avgmax':
+            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
+        elif pool_type == 'max':
+            x = x.amax(dim=1)
+        else:
+            assert not pool_type, f'Unknown pool type {pool_type}'
+
+    return x
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer

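For reference, a minimal sketch of how the new `global_pool_nlc` helper behaves on an NLC-shaped (batch, length, channels) tensor. The import path assumes this diff lands in timm's `vision_transformer.py`, as the surrounding code suggests; the shapes are illustrative.

```python
import torch
from timm.models.vision_transformer import global_pool_nlc  # assumed import path

x = torch.randn(2, 5, 8)  # batch of 2, 1 class token + 4 patch tokens, 8 channels

tok = global_pool_nlc(x, pool_type='token')   # class token -> (2, 8)
avg = global_pool_nlc(x, pool_type='avg')     # mean over the 4 patch tokens -> (2, 8)
mx  = global_pool_nlc(x, pool_type='max')     # amax over the 4 patch tokens -> (2, 8)
am  = global_pool_nlc(x, pool_type='avgmax')  # 0.5 * (amax + mean) -> (2, 8)
seq = global_pool_nlc(x, pool_type='')        # empty pool type is a no-op -> (2, 5, 8)

assert tok.shape == avg.shape == mx.shape == am.shape == (2, 8)
assert seq.shape == (2, 5, 8)
```

By default the prefix (class) tokens are excluded from the reductions; `reduce_include_prefix=True` keeps them in.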
@@ -400,7 +425,7 @@ def __init__(
             patch_size: Union[int, Tuple[int, int]] = 16,
             in_chans: int = 3,
             num_classes: int = 1000,
-            global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token',
+            global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
             embed_dim: int = 768,
             depth: int = 12,
             num_heads: int = 12,
@@ -459,10 +484,10 @@ def __init__(
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'max', 'token', 'map')
+        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         assert class_token or global_pool != 'token'
         assert pos_embed in ('', 'none', 'learn')
-        use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm
+        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU

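A side effect of the `use_fc_norm` change: with `fc_norm` left at its `None` default, the post-pool norm now resolves on for all three reducing pool types, while 'token' and 'map' keep the pre-pool norm. A minimal sketch, assuming direct construction of the class:

```python
from timm.models.vision_transformer import VisionTransformer

# fc_norm=None resolves to: global_pool in ('avg', 'avgmax', 'max')
vit_avgmax = VisionTransformer(global_pool='avgmax')  # norm applied after pooling (fc_norm)
vit_token = VisionTransformer(global_pool='token')    # norm applied before pooling
```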
@@ -596,10 +621,10 @@ def set_grad_checkpointing(self, enable: bool = True) -> None:
     def get_classifier(self) -> nn.Module:
         return self.head

-    def reset_classifier(self, num_classes: int, global_pool=None) -> None:
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token', 'map')
+            assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
             if global_pool == 'map' and self.attn_pool is None:
                 assert False, "Cannot currently add attention pooling in reset_classifier()."
             elif global_pool != 'map' and self.attn_pool is not None:
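The widened assert means the new pool types can also be selected after construction through `reset_classifier()`. A hedged usage sketch; the model name is illustrative:

```python
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False)

# Swap in a 10-class head and switch to avg+max pooling in one call.
# Note 'map' (attention pooling) still cannot be added here if the model
# was not built with it, per the assert above.
model.reset_classifier(num_classes=10, global_pool='avgmax')
```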
@@ -756,15 +781,16 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.norm(x)
         return x

-    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+    def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
         if self.attn_pool is not None:
             x = self.attn_pool(x)
-        elif self.global_pool == 'avg':
-            x = x[:, self.num_prefix_tokens:].mean(dim=1)
-        elif self.global_pool == 'max':
-            x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1)
-        elif self.global_pool:
-            x = x[:, 0]  # class token
+            return x
+        pool_type = self.global_pool if pool_type is None else pool_type
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        x = self.pool(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
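With pooling factored out of `forward_head()` into a standalone `pool()` method, features can now be pooled with a per-call `pool_type` override without touching model state. A usage sketch under the same assumptions as above:

```python
import timm
import torch

model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()

with torch.no_grad():
    feats = model.forward_features(torch.randn(1, 3, 224, 224))  # (1, 197, 768)
    tok = model.pool(feats)                   # defaults to model.global_pool ('token')
    avg = model.pool(feats, pool_type='avg')  # per-call override, model state unchanged

assert tok.shape == avg.shape == (1, 768)
```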