 
 import torch
 import torch.nn as nn
-from einops import rearrange, repeat
 from torch import nn, Tensor, broadcast_tensors, einsum
 from torch.nn import functional as F
 from torch.nn import Module, ModuleList
@@ -49,17 +48,14 @@
 def exists(val):
     return val is not None
 
-
 def default(val, d):
     return val if exists(val) else d
 
-
 def rotate_half(x):
-    x = rearrange(x, "... (d r) -> ... d r", r=2)
-    x1, x2 = x.unbind(dim=-1)
+    x = x.view(*x.shape[:-1], -1, 2)
+    x1, x2 = x[..., 0], x[..., 1]
     x = torch.stack((-x2, x1), dim=-1)
-    return rearrange(x, "... d r -> ... (d r)")
-
+    return x.view(*x.shape[:-2], -1)
 
 @autocast("cuda", enabled=False)
 def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
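
Quick sanity check for this hunk (assumes einops is still installed locally for comparison, since this change removes it from the imports):

    import torch
    from einops import rearrange

    def rotate_half_old(x):
        x = rearrange(x, "... (d r) -> ... d r", r=2)
        x1, x2 = x.unbind(dim=-1)
        x = torch.stack((-x2, x1), dim=-1)
        return rearrange(x, "... d r -> ... (d r)")

    def rotate_half_new(x):
        x = x.view(*x.shape[:-1], -1, 2)    # split last dim into consecutive pairs
        x1, x2 = x[..., 0], x[..., 1]       # first / second element of each pair
        x = torch.stack((-x2, x1), dim=-1)  # rotate each pair by 90 degrees
        return x.view(*x.shape[:-2], -1)    # merge pairs back into the last dim

    x = torch.randn(2, 8, 64)
    assert torch.equal(rotate_half_old(x), rotate_half_new(x))
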
@@ -86,7 +82,6 @@ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
 
     return out.type(dtype)
 
-
 class RotaryEmbedding(Module):
     def __init__(
         self,
@@ -187,7 +182,7 @@ def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
         )
 
         if seq_dim == -3:
-            freqs = rearrange(freqs, "n d -> n 1 d")
+            freqs = freqs.unsqueeze(1)
 
         return apply_rotary_emb(freqs, t, seq_dim=seq_dim)
 
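For reference, freqs.unsqueeze(1) inserts the same singleton head axis as rearrange(freqs, "n d -> n 1 d"), so a (seq, dim) tensor becomes (seq, 1, dim) and broadcasts over heads when seq_dim == -3; the same pattern covers the scale tensors in the next hunk and power.unsqueeze(-1) in get_scale. A one-line check:

    import torch
    assert torch.randn(128, 64).unsqueeze(1).shape == (128, 1, 64)
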
@@ -217,8 +212,8 @@ def rotate_queries_and_keys(self, q, k, seq_dim=None):
         scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
 
         if seq_dim == -3:
-            freqs = rearrange(freqs, "n d -> n 1 d")
-            scale = rearrange(scale, "n d -> n 1 d")
+            freqs = freqs.unsqueeze(1)
+            scale = scale.unsqueeze(1)
 
         rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
         rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)
@@ -230,7 +225,6 @@ def rotate_queries_and_keys(self, q, k, seq_dim=None):
 
     def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
         assert self.use_xpos
-
         should_cache = self.cache_if_possible and exists(seq_len)
 
         if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]:
@@ -239,7 +233,7 @@ def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
         scale = 1.0
         if self.use_xpos:
             power = (t - len(t) // 2) / self.scale_base
-            scale = self.scale ** rearrange(power, "n -> n 1")
+            scale = self.scale ** power.unsqueeze(-1)
             scale = torch.cat((scale, scale), dim=-1)
 
         if should_cache:
@@ -280,7 +274,7 @@ def forward(self, t: Tensor, seq_len=None, offset=0):
             freqs = self.freqs
 
         freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
-        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+        freqs = freqs.repeat_interleave(2, dim=-1)
 
         if should_cache:
             self.tmp_store("cached_freqs", freqs.detach())
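
Quick check that repeat_interleave reproduces the einops repeat pattern "... n -> ... (n r)" with r=2, i.e. each frequency duplicated in place along the last dim (again assumes einops is available for the comparison):

    import torch
    from einops import repeat

    freqs = torch.randn(4, 32)
    assert torch.equal(
        repeat(freqs, "... n -> ... (n r)", r=2),  # [f0, f0, f1, f1, ...]
        freqs.repeat_interleave(2, dim=-1),
    )
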
@@ -414,15 +408,15 @@ def forward(self, x, attn_mask=None):
         q, k, v = proj[0], proj[1], proj[2]
 
         # Use "q_" so that we don't accidentally quit in pdb :)
-        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
-        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
-        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
+        q = q.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        k = k.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        v = v.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
 
         if self.rope:
             q, k = self.rope(q, k)
 
         attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
-        attn = rearrange(attn, "b h s d -> b s (h d)")
+        attn = attn.permute(0, 2, 1, 3).contiguous().view(batch, seq, -1)
 
         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
 
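Sanity check for the head split/merge in this hunk, using stand-in shapes (the batch, seq, num_heads, head_dim values below are illustrative, and einops is assumed available for the comparison):

    import torch
    from einops import rearrange

    batch, seq, num_heads, head_dim = 2, 16, 4, 32
    q = torch.randn(batch, seq, num_heads * head_dim)

    # split heads: b s (h d) -> b h s d
    ref = rearrange(q, "b s (h d) -> b h s d", h=num_heads)
    new = q.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)
    assert torch.equal(ref, new)

    # merge heads back: b h s d -> b s (h d)
    attn = torch.randn(batch, num_heads, seq, head_dim)
    ref = rearrange(attn, "b h s d -> b s (h d)")
    new = attn.permute(0, 2, 1, 3).contiguous().view(batch, seq, -1)
    assert torch.equal(ref, new)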