Skip to content

Commit 89d348d

Browse files
committed
PE model working with timm train script, fix nn.Buffer -> register_buffer, add drop_rate arg
1 parent 414b775 commit 89d348d

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

timm/models/pe.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(
6161
if learned_freq:
6262
self.freqs = nn.Parameter(freqs)
6363
else:
64-
self.freqs = nn.Buffer(freqs, persistent=False)
64+
self.register_buffer('freqs', freqs, persistent=False)
6565

6666
def forward(self, t: Tensor):
6767
freqs = self.freqs
@@ -97,7 +97,7 @@ def update_grid(self, grid_h, grid_w):
9797

9898
if self.use_cls_token:
9999
freq = torch.cat([freq, torch.zeros(1, freq.shape[-1])], dim=0)
100-
self.freq = nn.Buffer(freq[None, ...], persistent=False)
100+
self.register_buffer('freq', freq[None, ...], persistent=False)
101101

102102
def rotate_half(self, x):
103103
shape = x.shape
@@ -382,13 +382,16 @@ def __init__(
382382
attn_pooler_heads: int = 8,
383383
use_attn_pool: bool = True,
384384
in_chans: int = 3,
385+
drop_rate: float = 0., # Expected to be here, TODO add a final drop layer once head finalized
385386
):
386387
super().__init__()
387388
self.patch_size = patch_size
388389
self.heads = heads
389390
self.width = width
390391
self.layers = layers
391392
self.in_chans = in_chans
393+
self.num_classes = num_classes
394+
self.drop_rate = drop_rate
392395

393396
# PE contains an (optional) projection layer
394397
# Flow: x -> Transformer(x) -> pool -> proj -> head (for timm).
@@ -494,16 +497,13 @@ def init_submodule_tensors(module):
494497
# PE's: Transformer(x) -> pool -> proj -> head (for timm). (PE contains an additional projection layer)
495498
if self.use_proj:
496499
self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim))
497-
if self.num_classes > 0:
498-
self.head = nn.Linear(self.proj_dim, self.num_classes)
499-
else:
500-
self.head = nn.Identity()
501500
else: # no projection (eg PE-lang and PE-spatial)
502501
self.proj = None
503-
if self.num_classes > 0:
504-
self.head = nn.Linear(self.width, self.num_classes) # no proj. input dim = self.width (pooled)
505-
else:
506-
self.head = nn.Identity()
502+
503+
if self.num_classes > 0:
504+
self.head = nn.Linear(self.head_hidden_size, self.num_classes) # no proj. input dim = self.width (pooled)
505+
else:
506+
self.head = nn.Identity()
507507

508508
def truncate(self, layer_idx: int):
509509
"""Delete layers so the last layer is the given layer index."""
@@ -671,7 +671,8 @@ def _cfg(url='', **kwargs):
671671

672672
default_cfgs = generate_default_cfgs(
673673
{
674-
'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='timm/', input_size=(3, 224, 224)),
674+
# TODO finalize locations
675+
'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='facebook/pe_core_base_patch16_224_timm', input_size=(3, 224, 224)),
675676
'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
676677
'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
677678
'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),

0 commit comments

Comments
 (0)