@@ -61,7 +61,7 @@ def __init__(
         if learned_freq:
             self.freqs = nn.Parameter(freqs)
         else:
-            self.freqs = nn.Buffer(freqs, persistent=False)
+            self.register_buffer('freqs', freqs, persistent=False)
 
     def forward(self, t: Tensor):
         freqs = self.freqs
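Context for this change: `nn.Buffer` only exists in recent PyTorch (added in 2.5), while `self.register_buffer(..., persistent=False)` has the same semantics and works on all supported releases: the tensor follows module device/dtype moves but stays out of `state_dict()`, so checkpoints are unaffected. A minimal standalone sketch of the pattern (names here are illustrative, not from the PE code):

```python
import torch
import torch.nn as nn

class FreqHolder(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        # Inverse-frequency table, as commonly used for rotary embeddings.
        freqs = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent: moves with the module but is excluded from checkpoints.
        self.register_buffer('freqs', freqs, persistent=False)

m = FreqHolder()
assert 'freqs' not in m.state_dict()  # excluded from state_dict
m = m.to(torch.float64)               # but still tracks dtype/device moves
print(m.freqs.dtype)                  # torch.float64
```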
@@ -97,7 +97,7 @@ def update_grid(self, grid_h, grid_w):
 
         if self.use_cls_token:
             freq = torch.cat([freq, torch.zeros(1, freq.shape[-1])], dim=0)
-        self.freq = nn.Buffer(freq[None, ...], persistent=False)
+        self.register_buffer('freq', freq[None, ...], persistent=False)
 
     def rotate_half(self, x):
         shape = x.shape
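For reviewers unfamiliar with the helper: `rotate_half` is the standard RoPE primitive that treats feature channels as 2D pairs and rotates each pair by 90 degrees. The body is truncated by this hunk; the `shape = x.shape` line suggests the interleaved-pair variant, a sketch of which (an assumption, not necessarily this file's exact implementation) looks like:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Interleaved-pair convention: view the last dim as (d/2, 2) pairs,
    # map each pair (x1, x2) -> (-x2, x1), then flatten back.
    shape = x.shape
    x = x.view(*shape[:-1], -1, 2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return x.view(shape)

x = torch.randn(2, 4, 16)
assert torch.allclose(rotate_half(rotate_half(x)), -x)  # two 90° rotations negate
```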
@@ -382,13 +382,16 @@ def __init__(
             attn_pooler_heads: int = 8,
             use_attn_pool: bool = True,
             in_chans: int = 3,
+            drop_rate: float = 0.,  # Expected to be here, TODO add a final drop layer once head finalized
     ):
         super().__init__()
         self.patch_size = patch_size
         self.heads = heads
         self.width = width
         self.layers = layers
         self.in_chans = in_chans
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
 
         # PE contains an (optional) projection layer
         # Flow: x -> Transformer(x) -> pool -> proj -> head (for timm).
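Note that `drop_rate` is accepted and stored but, per the TODO, not yet applied anywhere. The conventional timm placement is a dropout directly before the classifier; a purely hypothetical sketch of how it could be wired once the head is finalized (PE's actual forward plumbing may differ):

```python
import torch.nn.functional as F

def forward_head(self, x):
    # Hypothetical: pooled features -> optional projection -> dropout -> classifier.
    if self.proj is not None:
        x = x @ self.proj
    x = F.dropout(x, p=self.drop_rate, training=self.training)
    return self.head(x)
```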
@@ -494,16 +497,13 @@ def init_submodule_tensors(module):
         # PE's: Transformer(x) -> pool -> proj -> head (for timm). (PE contains an additional projection layer)
         if self.use_proj:
             self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim))
-            if self.num_classes > 0:
-                self.head = nn.Linear(self.proj_dim, self.num_classes)
-            else:
-                self.head = nn.Identity()
         else:  # no projection (eg PE-lang and PE-spatial)
             self.proj = None
-            if self.num_classes > 0:
-                self.head = nn.Linear(self.width, self.num_classes)  # no proj. input dim = self.width (pooled)
-            else:
-                self.head = nn.Identity()
+
+        if self.num_classes > 0:
+            self.head = nn.Linear(self.head_hidden_size, self.num_classes)  # head input dim: proj_dim if use_proj else width
+        else:
+            self.head = nn.Identity()
 
     def truncate(self, layer_idx: int):
         """Delete layers so the last layer is the given layer index."""
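The refactor deduplicates the two head-construction branches by keying the classifier's input dim off a single `head_hidden_size`. Its definition isn't shown in this hunk; presumably it resolves to the projected dim when the projection exists and the transformer width otherwise, i.e. something like:

```python
@property
def head_hidden_size(self) -> int:
    # Assumed definition: the pooled feature dim that feeds the classifier head.
    return self.proj_dim if self.use_proj else self.width
```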
@@ -671,7 +671,8 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = generate_default_cfgs(
     {
-        'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='timm/', input_size=(3, 224, 224)),
+        # TODO finalize locations
+        'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='facebook/pe_core_base_patch16_224_timm', input_size=(3, 224, 224)),
         'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
         'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
         'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
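Once the weight locations are finalized, these configs let the models load straight from the HF Hub through the standard timm entrypoint; note that per the TODO only the base config points at a concrete repo so far, and the remaining `'timm/'` hub ids are placeholders:

```python
import timm

# Assumes the facebook/pe_core_base_patch16_224_timm weights are published.
model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=True)
model.eval()
```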