@@ -219,8 +219,6 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
219219
def apply_rot_embed_cat(x: torch.Tensor, emb):
    """Rotate ``x`` using an embedding that packs sin and cos side by side.

    Args:
        x: Tensor to which the rotary embedding is applied.
        emb: Tensor whose last dimension is split into two equal halves,
            the first used as the sin term and the second as the cos term.
            Both halves are combined with ``x`` by plain broadcasting.

    Returns:
        ``x * cos + rot(x) * sin``.
    """
    sin_half, cos_half = emb.tensor_split(2, dim=-1)
    straight = x * cos_half
    rotated = rot(x) * sin_half
    return straight + rotated
225223
226224
@@ -351,6 +349,7 @@ def __init__(
351349 ref_feat_shape = self .ref_feat_shape ,
352350 grid_offset = self .grid_offset ,
353351 grid_indexing = self .grid_indexing ,
352+ temperature = self .temperature ,
354353 )
355354 self .bands = None
356355 self .register_buffer (
@@ -446,6 +445,7 @@ def __init__(
446445 ref_feat_shape = self .ref_feat_shape ,
447446 grid_offset = self .grid_offset ,
448447 grid_indexing = self .grid_indexing ,
448+ temperature = self .temperature ,
449449 )
450450 self .bands = None
451451 self .register_buffer (
@@ -475,3 +475,125 @@ def forward(self, x):
475475 # assuming channel-first tensor where spatial dim are >= 2
476476 pos_embed = self .get_embed (x .shape [2 :])
477477 return apply_rot_embed_cat (x , pos_embed )
478+
479+
def init_random_2d_freqs(
        head_dim: int,
        depth: int,
        num_heads: int,
        temperature: float = 10.0,
        rotate: bool = True,
        *,
        device=None,
        dtype=torch.float32,
) -> torch.Tensor:
    """Build initial 2D ROPE frequencies for mixed-mode (learnable) ROPE.

    Every (block, head) pair shares the same band magnitudes but gets its own
    rotation angle, drawn uniformly from [0, 2*pi) when ``rotate`` is set and
    fixed at zero otherwise.

    Args:
        head_dim: Per-head channel count; one magnitude is produced for every
            fourth channel index.
        depth: Number of blocks.
        num_heads: Number of attention heads.
        temperature: Base of the geometric magnitude schedule.
        rotate: Whether to draw a random rotation angle per (block, head).
        device: Device for the created tensors.
        dtype: Dtype for the created tensors.

    Returns:
        Tensor of shape (2, depth, num_heads, head_dim // 2); index 0 holds
        the x frequencies and index 1 the y frequencies.
    """
    factory_kwargs = dict(device=device, dtype=dtype)

    # One magnitude per group of four channels, shaped (1, 1, head_dim // 4)
    # so it broadcasts across the depth and head axes.
    exponents = torch.arange(0, head_dim, 4, **factory_kwargs) / head_dim
    band_mag = (1.0 / (temperature ** exponents))[None, None, :]

    # One angle per block and per head (or all zeros when rotation is off).
    angle_shape = (depth, num_heads, 1)
    if rotate:
        theta = torch.rand(*angle_shape, **factory_kwargs) * 2 * torch.pi
    else:
        theta = torch.zeros(*angle_shape, **factory_kwargs)
    theta_perp = theta + torch.pi / 2

    # Two halves of head_dim // 4 each -> head_dim // 2 on the last axis.
    x_halves = [band_mag * torch.cos(theta), band_mag * torch.cos(theta_perp)]
    y_halves = [band_mag * torch.sin(theta), band_mag * torch.sin(theta_perp)]
    freqs_x = torch.cat(x_halves, dim=-1)
    freqs_y = torch.cat(y_halves, dim=-1)

    return torch.stack([freqs_x, freqs_y], dim=0)
512+
513+
class RotaryEmbeddingMixed(nn.Module):
    """Rotary position embedding with depth-dependent learnable frequencies.

    This implementation supports mixed (learnable) ROPE. In mixed mode,
    each transformer block has its own set of learnable frequency parameters,
    stored in ``self.freqs`` with shape (2, depth, num_heads, head_dim // 2).
    """
    def __init__(
            self,
            dim: int,
            depth: int,
            num_heads: int,
            temperature: float = 10.0,
            feat_shape: Optional[List[int]] = None,
            grid_indexing: str = 'xy',
    ):
        """Initialize rotary embeddings.

        Args:
            dim: Embedding dimension (dim // num_heads must be divisible by 4)
            depth: Number of transformer blocks
            num_heads: Number of attention heads
            temperature: Base for frequency computation
            feat_shape: Spatial dimensions [H, W] if known in advance; used by
                get_embed() when no shape is passed
            grid_indexing: How to index grid positions ('xy' or 'ij')
        """
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.temperature = temperature
        self.feat_shape = feat_shape
        self.grid_indexing = grid_indexing

        head_dim = dim // num_heads
        assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}"
        freqs = init_random_2d_freqs(
            head_dim,
            depth,
            num_heads,
            temperature=temperature,
            rotate=True,
        )  # (2, depth, num_heads, head_dim//2)
        self.freqs = nn.Parameter(freqs)

    def get_mixed_freqs(self, H: int, W: int, device: torch.device, dtype: torch.dtype):
        """Compute mixed (learnable) sin/cos embeddings for an H x W grid.

        Positions are emitted in row-major (H, W) token order, i.e. token
        ``k`` sits at row ``k // W`` and column ``k % W``. With 'xy' indexing
        the x coordinate is the column and y the row; with 'ij' indexing x is
        the row and y the column.

        Returns:
            Tensor of shape (depth, num_heads, H*W, 2*head_dim) with the sin
            half followed by the cos half on the last axis.
        """
        # torch.meshgrid(..., indexing='xy') swaps the first two output dims,
        # so the sizes are passed as (W, H) in that mode to keep the flattened
        # grid aligned with (H, W) row-major tokens on non-square inputs.
        if self.grid_indexing == 'xy':
            size_0, size_1 = W, H
        else:
            size_0, size_1 = H, W
        x_pos, y_pos = torch.meshgrid(
            torch.arange(size_0, dtype=dtype, device=device),
            torch.arange(size_1, dtype=dtype, device=device),
            indexing=self.grid_indexing,
        )
        t_x = x_pos.flatten()  # (N,)
        t_y = y_pos.flatten()  # (N,)

        # (N, 1) @ (depth, num_heads, 1, head_dim//2) -> (depth, num_heads, N, head_dim//2)
        freqs_x = t_x.unsqueeze(-1) @ self.freqs[0].unsqueeze(-2)
        freqs_y = t_y.unsqueeze(-1) @ self.freqs[1].unsqueeze(-2)
        combined = freqs_x + freqs_y  # (depth, num_heads, N, head_dim//2)

        # Each angle is duplicated for a pair of adjacent channels.
        sin_emb = torch.sin(combined).repeat_interleave(2, -1)  # (depth, num_heads, N, head_dim)
        cos_emb = torch.cos(combined).repeat_interleave(2, -1)  # (depth, num_heads, N, head_dim)
        rope_embeds = torch.cat([sin_emb, cos_emb], dim=-1)  # (depth, num_heads, N, 2*head_dim)

        return rope_embeds

    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
        """Generate rotary embeddings for the given spatial shape.

        Args:
            shape: Spatial dimensions [H, W]. When omitted, the ``feat_shape``
                given at construction is used.

        Returns:
            Tensor of shape (depth, num_heads, H*W, 2*head_dim) containing
            concatenated sin/cos embeddings for every block and head.
        """
        if shape is None:
            shape = self.feat_shape
        assert shape is not None, "shape must be provided"
        H, W = shape
        device = self.freqs.device
        dtype = self.freqs.dtype
        return self.get_mixed_freqs(H, W, device, dtype)

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        # NOTE(review): get_embed() returns embeddings for all blocks with
        # leading (depth, num_heads) axes; presumably callers index one block
        # and call apply_rot_embed_cat themselves -- confirm this path is used.
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)

    def no_weight_decay(self):
        """Exclude frequency parameters from weight decay."""
        return {'freqs'}
0 commit comments