Skip to content

Commit 6f452af

Browse files
committed
Adding Naver rope-vit compatibility to EVA ViT
1 parent 8d41071 commit 6f452af

File tree

3 files changed

+596
-38
lines changed

3 files changed

+596
-38
lines changed

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit
5050
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
5151
build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
52-
FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
52+
FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat, RotaryEmbeddingMixed
5353
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
5454
from .selective_kernel import SelectiveKernel
5555
from .separable_conv import SeparableConv2d, SeparableConvNormAct

timm/layers/pos_embed_sincos.py

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,6 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
219219

220220
def apply_rot_embed_cat(x: torch.Tensor, emb):
221221
sin_emb, cos_emb = emb.tensor_split(2, -1)
222-
if sin_emb.ndim == 3:
223-
return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
224222
return x * cos_emb + rot(x) * sin_emb
225223

226224

@@ -351,6 +349,7 @@ def __init__(
351349
ref_feat_shape=self.ref_feat_shape,
352350
grid_offset=self.grid_offset,
353351
grid_indexing=self.grid_indexing,
352+
temperature=self.temperature,
354353
)
355354
self.bands = None
356355
self.register_buffer(
@@ -446,6 +445,7 @@ def __init__(
446445
ref_feat_shape=self.ref_feat_shape,
447446
grid_offset=self.grid_offset,
448447
grid_indexing=self.grid_indexing,
448+
temperature=self.temperature,
449449
)
450450
self.bands = None
451451
self.register_buffer(
@@ -475,3 +475,125 @@ def forward(self, x):
475475
# assuming channel-first tensor where spatial dim are >= 2
476476
pos_embed = self.get_embed(x.shape[2:])
477477
return apply_rot_embed_cat(x, pos_embed)
478+
479+
480+
def init_random_2d_freqs(
        head_dim: int,
        depth: int,
        num_heads: int,
        temperature: float = 10.0,
        rotate: bool = True,
        *,
        device=None,
        dtype=torch.float32,
) -> torch.Tensor:
    """ Vectorised 2D ROPE frequencies with random rotation for mixed mode ROPE.

    Args:
        head_dim: Per-head channel count, must be divisible by 4.
        depth: Number of transformer blocks (one frequency set per block).
        num_heads: Number of attention heads (one frequency set per head).
        temperature: Base of the geometric progression of frequency magnitudes.
        rotate: If True, draw one random angle in [0, 2*pi) per (block, head);
            if False, all angles are zero.
        device: Device for the returned tensor.
        dtype: Dtype for the returned tensor.

    Returns:
        Tensor (2, depth, num_heads, head_dim//2), index 0 holds the x
        frequencies and index 1 the y frequencies.

    Raises:
        ValueError: If head_dim is not divisible by 4.
    """
    if head_dim % 4 != 0:
        # arange(0, head_dim, 4) would yield ceil(head_dim / 4) magnitudes, so the
        # last axis would silently stop matching head_dim // 2.
        raise ValueError(f"head_dim must be divisible by 4, got {head_dim}")

    # base magnitudes, shape: (head_dim//4,)
    mag = 1.0 / (temperature ** (torch.arange(0, head_dim, 4, device=device, dtype=dtype) / head_dim))

    # (1,1,L) so it broadcasts over both depth and heads
    mag = mag.unsqueeze(0).unsqueeze(0)

    # random (or zero) rotation per head *and* per block
    if rotate:
        angles = torch.rand(depth, num_heads, 1, device=device, dtype=dtype) * 2 * torch.pi
    else:
        angles = torch.zeros(depth, num_heads, 1, device=device, dtype=dtype)

    # build (depth, num_heads, 2·L) == head_dim//2 on the last axis; the second half
    # uses the same angle shifted by pi/2, i.e. the orthogonal direction
    fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(angles + torch.pi / 2)], dim=-1)
    fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(angles + torch.pi / 2)], dim=-1)

    # (2, depth, num_heads, head_dim//2)
    return torch.stack([fx, fy], dim=0)
512+
513+
514+
class RotaryEmbeddingMixed(nn.Module):
    """Rotary position embedding with depth-dependent learnable frequencies.

    This implementation supports mixed (learnable) ROPE. In mixed mode,
    each transformer block has its own set of learnable frequency parameters.
    """
    def __init__(
            self,
            dim: int,
            depth: int,
            num_heads: int,
            temperature: float = 10.0,
            feat_shape: Optional[List[int]] = None,
            grid_indexing: str = 'xy',
    ):
        """Initialize rotary embeddings.

        Args:
            dim: Embedding dimension (dim // num_heads should be divisible by 4)
            depth: Number of transformer blocks
            num_heads: Number of attention heads
            temperature: Base for frequency computation
            feat_shape: Spatial dimensions [H, W] if known in advance
            grid_indexing: How to index grid positions ('xy' or 'ij')
        """
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.temperature = temperature
        self.feat_shape = feat_shape
        self.grid_indexing = grid_indexing

        head_dim = dim // num_heads
        assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}"
        freqs = init_random_2d_freqs(
            head_dim,
            depth,
            num_heads,
            temperature=temperature,
            rotate=True,
        )  # (2, depth, num_heads, head_dim//2)
        self.freqs = nn.Parameter(freqs)

    def get_mixed_freqs(self, H: int, W: int, device: torch.device, dtype: torch.dtype):
        """Compute mixed (learnable) sin/cos embeddings for an H x W grid.

        Args:
            H: Grid height.
            W: Grid width.
            device: Device for the position indices.
            dtype: Dtype for the position indices.

        Returns:
            Tensor of shape (depth, num_heads, H*W, 2 * head_dim) with the sin
            half followed by the cos half on the last axis.
        """
        # With 'xy' indexing torch.meshgrid returns grids of shape (len(second), len(first)),
        # so the axis lengths must be swapped to keep the flattened positions in the
        # row-major (H, W) order. Square grids are unaffected by the swap.
        if self.grid_indexing == 'xy':
            first, second = W, H
        else:
            first, second = H, W

        # Create position indices, each grid has shape (H, W)
        x_pos, y_pos = torch.meshgrid(
            torch.arange(first, dtype=dtype, device=device),
            torch.arange(second, dtype=dtype, device=device),
            indexing=self.grid_indexing,
        )
        t_x = x_pos.flatten()  # (N,) with N = H*W
        t_y = y_pos.flatten()  # (N,)

        # (N, 1) @ (depth, num_heads, 1, head_dim//2) -> (depth, num_heads, N, head_dim//2)
        freqs_x = (t_x.unsqueeze(-1) @ self.freqs[0].unsqueeze(-2))
        freqs_y = (t_y.unsqueeze(-1) @ self.freqs[1].unsqueeze(-2))
        combined = freqs_x + freqs_y  # (depth, num_heads, N, head_dim//2)
        sin_emb = torch.sin(combined).repeat_interleave(2, -1)  # (depth, num_heads, N, head_dim)
        cos_emb = torch.cos(combined).repeat_interleave(2, -1)  # (depth, num_heads, N, head_dim)
        rope_embeds = torch.cat([sin_emb, cos_emb], dim=-1)  # (depth, num_heads, N, 2 * head_dim)

        return rope_embeds

    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
        """Generate rotary embeddings for the given spatial shape.

        Args:
            shape: Spatial dimensions [H, W]. Falls back to the `feat_shape`
                given at construction when omitted.

        Returns:
            Tensor of shape (depth, num_heads, H*W, 2 * head_dim) containing
            concatenated sin/cos embeddings
        """
        if shape is None:
            shape = self.feat_shape
        assert shape is not None, "shape must be provided"
        H, W = shape
        device = self.freqs.device
        dtype = self.freqs.dtype
        return self.get_mixed_freqs(H, W, device, dtype)

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        # NOTE(review): the embedding carries leading (depth, num_heads) axes; confirm
        # that broadcasting it against x here is what callers expect.
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)

    def no_weight_decay(self):
        """Exclude frequency parameters from weight decay."""
        return {'freqs'}

0 commit comments

Comments
 (0)