
Commit 2ba13ee

fix rope bug
1 parent 4746845 commit 2ba13ee

File tree

1 file changed (+54, -62 lines)


timm/models/pe.py

Lines changed: 54 additions & 62 deletions
@@ -41,11 +41,11 @@ class RotaryEmbedding(Module):
     def __init__(
             self,
             dim,
-            freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang",
+            freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang",
             theta=10000,
             max_freq=10,
             num_freqs=1,
-            learned_freq=False,
+            learned_freq=False,
             theta_rescale_factor=1.0,
     ):
         super().__init__()
@@ -73,7 +73,6 @@ def forward(self, t: Tensor):
         return freqs


-
 @register_notrace_module
 class Rope2D(Module):
     def __init__(self, dim, grid_size, use_cls_token=False):
@@ -83,10 +82,10 @@ def __init__(self, dim, grid_size, use_cls_token=False):
         self.grid_size = grid_size
         self.rope = RotaryEmbedding(self.dim // 2)
         self.init_tensors()
-
+
     def init_tensors(self):
         self.update_grid(self.grid_size[0], self.grid_size[1])
-
+
     def update_grid(self, grid_h, grid_w):
         if self.use_cls_token:
             # +1 to leave space for the cls token to be (0, 0)
@@ -100,22 +99,22 @@ def update_grid(self, grid_h, grid_w):
         freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)

         if self.use_cls_token:
-            freq = torch.cat([freq, torch.zeros(1, freq.shape[-1])], dim=0)
+            freq = torch.cat([torch.zeros(1, freq.shape[-1]), freq], dim=0)
         self.register_buffer('freq', freq[None, ...], persistent=False)

     def rotate_half(self, x):
-        shape = x.shape
+        shape = x.shape
         x = x.view(shape[:-1] + (-1, 2))
         x1, x2 = x[..., 0], x[..., 1]
         x = torch.stack((-x2, x1), dim=-1)
         return x.view(shape[:-1] + (-1,))
-
+
     def apply_rotary_emb(self, freqs, t):
         start_index = 0
         scale = 1.0
         seq_dim = -2
         dtype = t.dtype
-
+
         # if len(t.shape) == 3:
         #     seq_len = t.shape[seq_dim]
         #     freqs = freqs[-seq_len:]
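This hunk is the actual "rope bug" fix. forward_features concatenates the class token in front of the patch tokens (sequence index 0), and update_grid reserves a (0, 0) position for it, but the old code appended the all-zero frequency row after the grid frequencies. The identity rotation therefore landed on the last patch while the cls token picked up the first grid position's angles. Prepending the zero row lines it up with the cls token. A minimal, standalone sketch of the ordering issue (illustrative shapes and names, not the timm code itself):

# Hedged sketch: why the zero-frequency row must be prepended when the cls token
# sits at sequence index 0. Shapes are illustrative.
import torch

grid_h = grid_w = 2
dim = 8                                       # per-position frequency width, illustrative

freqs = torch.randn(grid_h * grid_w, dim)     # stand-in for the grid frequencies from update_grid()
zeros = torch.zeros(1, dim)                   # zero angle -> cos = 1, sin = 0 -> identity rotation

buggy = torch.cat([freqs, zeros], dim=0)      # old: identity row ends up on the *last* patch
fixed = torch.cat([zeros, freqs], dim=0)      # new: identity row sits at index 0, where the cls token is

# tokens are ordered [cls, patch_0, ..., patch_{N-1}], so row 0 must be the zero row
assert torch.all(fixed[0] == 0)
assert not torch.all(buggy[0] == 0)

With an all-zero angle, apply_rotary_emb leaves that token untouched, which is exactly what the "(0, 0)" reservation in update_grid intends for the cls token.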
@@ -185,6 +184,7 @@ class SelfAttention(nn.Module):
     r"""
     Implements sequence packed attention and RoPe
     """
+
     fused_attn: Final[bool]

     def __init__(
@@ -214,11 +214,11 @@ def init_tensors(self):
         constant_(self.in_proj_bias, 0.0)
         constant_(self.out_proj.bias, 0.0)

-
-    def forward(self,
-            x: torch.Tensor,
-            attn_mask: Optional[torch.Tensor] = None,
-    ):
+    def forward(
+            self,
+            x: torch.Tensor,
+            attn_mask: Optional[torch.Tensor] = None,
+    ):
         batch, seq, embed_dim = x.shape
         proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)

@@ -235,7 +235,9 @@ def forward(self,
         q, k = self.rope(q, k)

         if self.fused_attn:
-            attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
+            attn = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
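The change above only re-wraps the fused call for line length; both branches compute the same attention. A hedged equivalence check, as a standalone sketch with made-up shapes (the scale= keyword of scaled_dot_product_attention needs a reasonably recent PyTorch):

# Standalone sketch: fused SDPA vs. the manual else-branch give the same result
# when there is no mask and no dropout. Shapes are illustrative.
import torch
import torch.nn.functional as F

B, H, N, D = 2, 4, 16, 32                     # batch, heads, tokens, head dim
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
scale = D ** -0.5

fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, scale=scale)

attn = (q * scale) @ k.transpose(-2, -1)      # manual path
manual = attn.softmax(dim=-1) @ v

assert torch.allclose(fused, manual, atol=1e-5)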
@@ -247,8 +249,6 @@ def forward(self,
         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)


-
-
 class ResidualAttentionBlock(nn.Module):
     def __init__(
             self,
@@ -285,11 +285,7 @@ def __init__(
             )
         )

-    def _call_attn(
-            self,
-            q_x: torch.Tensor,
-            attn_mask: Optional[torch.Tensor] = None
-    ):
+    def _call_attn(self, q_x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         if attn_mask is not None:
             # Leave boolean masks as is
             if not attn_mask.dtype == torch.bool:
@@ -300,7 +296,7 @@ def _call_attn(
     def forward(
             self,
             x: torch.Tensor,
-            attn_mask: Optional[torch.Tensor] = None,
+            attn_mask: Optional[torch.Tensor] = None,
     ):
         x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask)))
         x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
@@ -354,18 +350,14 @@ def truncate(self, layer_idx: int):
     def forward(
             self,
             x: torch.Tensor,
-            attn_mask: Optional[torch.Tensor] = None,
-            # layer_idx=-1, #: int = -1, # torchscript emits iterations over modules as unrolled loops. so dynamic layer_idx is not supported as in orig pe
+            attn_mask: Optional[torch.Tensor] = None,
     ):
-        #stop_idx = (self.layers + layer_idx) % self.layers
         for i, r in enumerate(self.resblocks):
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                 x = checkpoint(r, x, None, None, attn_mask)
             else:
                 x = r(x, attn_mask=attn_mask)
-            # if i == stop_idx:
-            #     break
         return x


@@ -389,11 +381,11 @@ def __init__(
             use_cls_token: bool = False,
             use_proj: bool = True,
             output_dim: Optional[int] = 1280,
-            num_classes: int = 0,
+            num_classes: int = 0,
             attn_pooler_heads: int = 8,
             use_attn_pool: bool = True,
             in_chans: int = 3,
-            drop_rate: float = 0., # Expected to be here, TODO add a final drop layer once head finalized
+            drop_rate: float = 0.0,  # Expected to be here, TODO add a final drop layer once head finalized
     ):
         super().__init__()
         self.patch_size = patch_size
@@ -404,7 +396,7 @@ def __init__(
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.emb_dim = width
-
+
         # PE contains an (optional) projection layer
         # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm).
         # forward_features: x -> Transfomer(x)
@@ -414,10 +406,10 @@ def __init__(
         if self.use_proj:
             self.proj_dim = output_dim
             self.head_hidden_size = self.proj_dim
-            self.num_features = width # self.proj_dim
+            self.num_features = width  # self.proj_dim
         else:
-            self.proj_dim = 0
-            assert output_dim == width
+            self.proj_dim = 0
+            assert output_dim == width
             self.head_hidden_size = width
             self.num_features = width

@@ -445,7 +437,7 @@ def __init__(
             Rope2D(
                 dim=width // heads,
                 use_cls_token=self.use_cls_token,
-                grid_size = (img_size // patch_size, img_size // patch_size),
+                grid_size=(img_size // patch_size, img_size // patch_size),
             )
             if self.use_rope2d
             else None
466458
rope=self.rope,
467459
)
468460

469-
self.feature_info = [
470-
dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)]
461+
self.feature_info = [dict(module=f'blocks.{i}', num_chs=width, reduction=patch_size) for i in range(layers)]
471462

472463
if use_attn_pool:
473464
self.attn_pool = AttentionPooling(
@@ -479,7 +470,7 @@ def __init__(
479470
else:
480471
self.attn_pool = None
481472

482-
self.head_act_layer = None # =act_layer if to add an additional activation between fc1(proj) and fc2(head)
473+
self.head_act_layer = None # =act_layer if to add an additional activation between fc1(proj) and fc2(head)
483474
self.init_tensors()
484475

485476
def init_tensors(self):
@@ -511,11 +502,11 @@ def init_submodule_tensors(module):
         # PE's: Transfomer(x) -> pool -> proj -> head (for timm). (PE contains an additional projection layer)
         if self.use_proj:
             self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim))
-        else: # no projection (eg PE-lang and PE-spatial)
+        else:  # no projection (eg PE-lang and PE-spatial)
             self.proj = None

         if self.num_classes > 0:
-            self.head = nn.Linear(self.head_hidden_size, self.num_classes) # no proj. input dim = self.width (pooled)
+            self.head = nn.Linear(self.head_hidden_size, self.num_classes)  # no proj. input dim = self.width (pooled)
         else:
             self.head = nn.Identity()

@@ -536,8 +527,8 @@ def forward_pool_and_proj(self, x: torch.Tensor):
         return x

     def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
-        # PE has an additional proj layer: Transfomer(x) -> pool -> proj -> head (for timm).
-        # To discuss with Ross where to split
+        # PE has an additional proj layer: Transfomer(x) -> pool -> proj -> head (for timm).
+        # To discuss with Ross where to split
         x = self.forward_pool_and_proj(x)
         if self.head_act_layer is not None:
             x = self.head_act_layer(x)
@@ -554,7 +545,7 @@ def forward_features(self, x: torch.Tensor, norm: bool = False):
                 [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
                 dim=1,
             )
-
+
         if self.positional_embedding is not None:
             x = x + self.positional_embedding[None, ...]

@@ -575,22 +566,22 @@ def reset_classifier(self, num_classes: int):
         if num_classes > 0:
             if self.proj is not None:
                 self.head = nn.Parameter(self.proj_dim, num_classes)
-            else: # no projection (eg PE-lang and PE-spatial)
+            else:  # no projection (eg PE-lang and PE-spatial)
                 self.head = nn.Parameter(self.width, num_classes)
         else:
             self.head = nn.Identity()

     def forward_intermediates(
-        self,
-        x: torch.Tensor,
-        indices: Optional[Union[int, List[int]]] = None,
-        return_prefix_tokens: bool = False,
-        norm: bool = False,
-        stop_early: bool = False,
-        output_fmt: str = 'NCHW',
-        intermediates_only: bool = False,
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
-        """ Forward features that returns intermediates.
+        """Forward features that returns intermediates.

         Args:
             x: Input image tensor
@@ -612,7 +603,7 @@ def forward_intermediates(
         B, _, height, width = x.shape
         # patch embedgging
         x = self.conv1(x)
-        x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width) # NLC
+        x = x.permute(0, 2, 3, 1).reshape(B, -1, self.width)  # NLC

         if self.class_embedding is not None:
             x = torch.cat(
@@ -628,7 +619,7 @@ def forward_intermediates(
         if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
             blocks = self.transformer.resblocks
         else:
-            blocks = self.transformer.resblocks[:max_index + 1]
+            blocks = self.transformer.resblocks[: max_index + 1]

         for i, blk in enumerate(blocks):
             x = blk(x)
@@ -638,7 +629,7 @@ def forward_intermediates(

         # process intermediates
         if self.class_embedding is not None:
-            prefix_tokens = [y[:, 0] for y in intermediates] # only one cls token in PE
+            prefix_tokens = [y[:, 0] for y in intermediates]  # only one cls token in PE
             intermediates = [y[:, 1:] for y in intermediates]
         else:
             prefix_tokens = None
@@ -657,7 +648,6 @@ def forward_intermediates(
         x = self.ln_post(x)

         return x, intermediates
-


 def checkpoint_filter_fn(
def checkpoint_filter_fn(
@@ -677,18 +667,20 @@ def _cfg(url='', **kwargs):
677667
'num_classes': 0,
678668
'interpolation': 'bilinear',
679669
'fixed_input_size': True,
680-
'mean': IMAGENET_INCEPTION_MEAN, # (0.5, 0.5, 0.5)
681-
'std': IMAGENET_INCEPTION_STD, # (0.5, 0.5, 0.5)
670+
'mean': IMAGENET_INCEPTION_MEAN, # (0.5, 0.5, 0.5)
671+
'std': IMAGENET_INCEPTION_STD, # (0.5, 0.5, 0.5)
682672
'first_conv': 'conv1',
683-
'classifier': 'head',
673+
'classifier': 'head',
684674
**kwargs,
685675
}
686676

687677

688678
default_cfgs = generate_default_cfgs(
689679
{
690680
# TODO finalize locations
691-
'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='facebook/pe_core_base_patch16_224_timm', input_size=(3, 224, 224)),
681+
'vit_pe_core_base_patch16_224': _cfg(
682+
hf_hub_id='facebook/pe_core_base_patch16_224_timm', input_size=(3, 224, 224)
683+
),
692684
'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
693685
'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
694686
'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
@@ -822,4 +814,4 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
         ls_init_value=0.1,
         use_proj=False,
     )
-    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
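For reference, a hedged usage sketch, assuming a timm build that includes this pe.py and registers the vit_pe_core_base_patch16_224 entry point shown in the default cfgs above; with the default cfg (num_classes=0) the forward pass returns pooled features rather than logits:

# Hedged usage sketch; model name taken from the default_cfgs in this diff.
import timm
import torch

# pretrained=False avoids depending on the hub locations above, which are marked TODO
model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)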
