 # are both < 2048 tokens.


-def rotaryembeddings(dim: int, maxseqlen=8192, base=10000):
+def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
     tmax = torch.arange(maxseqlen, device=inv_freq.device)
     rope = torch.outer(tmax, inv_freq).float()
     # rope is now matrix [maxseqlen, dim/2]
     rope = torch.polar(torch.ones_like(rope), rope)
+    rope = torch.cat((rope, rope), dim=1)
     return rope


-def apply_rotary_emb(query, key, rope, interleave=True):
-    query = query.transpose(1, 2)
-    key = key.transpose(1, 2)
-    if not interleave:
-        query = torch.cat(
-            (-query[..., query.shape[-1] // 2 :], query[..., : query.shape[-1] // 2]),
-            dim=-1,
-        )
-        key = torch.cat(
-            (-key[..., key.shape[-1] // 2 :], key[..., : key.shape[-1] // 2]), dim=-1
-        )
-    query_ = query.float().reshape(*query.shape[:-1], -1, 2)
-    query_ = torch.view_as_complex(query_)
-    key_ = key.float().reshape(*key.shape[:-1], -1, 2)
-    key_ = torch.view_as_complex(key_)
-    rope = rope.view(1, query_.size(1), 1, query_.size(3))
-    query_out = torch.view_as_real(query_ * rope).flatten(3)
-    key_out = torch.view_as_real(key_ * rope).flatten(3)
-    return query_out.transpose(1, 2).type_as(query), key_out.transpose(1, 2).type_as(
-        key
-    )
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_emb(query, key, rope, interleave):
+    if interleave:
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        query_ = query.float().reshape(*query.shape[:-1], -1, 2)
+        query_ = torch.view_as_complex(query_)
+        key_ = key.float().reshape(*key.shape[:-1], -1, 2)
+        key_ = torch.view_as_complex(key_)
+        rope = rope[:, : rope.size(1) // 2].view(1, query_.size(1), 1, query_.size(3))
+        query_out = torch.view_as_real(query_ * rope).flatten(3)
+        key_out = torch.view_as_real(key_ * rope).flatten(3)
+        return query_out.transpose(1, 2).type_as(query), key_out.transpose(
+            1, 2
+        ).type_as(key)
+    else:
+        cos, sin = rope.real, rope.imag
+        q_embed = (query * cos) + (rotate_half(query) * sin)
+        k_embed = (key * cos) + (rotate_half(key) * sin)
+        return q_embed.type_as(query), k_embed.type_as(key)


 # Help functions for max_relative positions
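
For reference, a minimal smoke test of the two patched helpers above (a sketch, not part of the diff; sizes are illustrative, and query/key use the [batch, heads, seqlen, dim_per_head] layout the attention module passes in):

import torch

dim_per_head, seqlen, batch, heads = 64, 32, 2, 4
rope = rotaryembeddings(dim_per_head)  # [2048, dim_per_head], complex64
cos, sin = rope.real, rope.imag  # full-width tables for the rotate_half path
# halves are duplicated by the new cat((rope, rope), dim=1)
assert torch.equal(cos[:, : dim_per_head // 2], cos[:, dim_per_head // 2 :])

query = torch.randn(batch, heads, seqlen, dim_per_head)
key = torch.randn(batch, heads, seqlen, dim_per_head)
q1, k1 = apply_rotary_emb(query, key, rope[:seqlen], interleave=True)
q2, k2 = apply_rotary_emb(query, key, rope[:seqlen], interleave=False)
assert q1.shape == q2.shape == query.shape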
@@ -412,6 +418,10 @@ def forward(
         if self.max_relative_positions == -1:  # Rotary Embeddings
             start_pos = step
             seqlen = query.size(2)
+            if seqlen > self.rope.size(0):
+                self.rope = rotaryembeddings(
+                    self.dim_per_head, maxseqlen=(seqlen + 2048)
+                )
             rope = self.rope[start_pos : start_pos + seqlen]
             query, key = apply_rotary_emb(
                 query, key, rope, interleave=self.rotary_interleave
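
Because the rotation is applied per position, slicing the table at step during incremental decoding matches rotating the whole sequence up front, which is what keeps the KV cache valid under rotary embeddings; a quick check of that property (a sketch with made-up sizes, not part of the PR):

import torch

dim, T = 64, 8
rope = rotaryembeddings(dim)
q = torch.randn(1, 2, T, dim)
k = torch.randn(1, 2, T, dim)
q_full, _ = apply_rotary_emb(q, k, rope[:T], interleave=True)
# decode-style call for the token at position 3 only
q_step, _ = apply_rotary_emb(q[:, :, 3:4], k[:, :, 3:4], rope[3:4], interleave=True)
assert torch.allclose(q_full[:, :, 3:4], q_step)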
@@ -444,14 +454,19 @@ def forward(
         key = self.maybe_ckpt(self.linear_keys, key)
         value = self.maybe_ckpt(self.linear_values, value)
         query = self.maybe_ckpt(self.linear_query, query)
+
         key = shape(key, self.dim_per_head)
         value = shape(value, self.dim_per_head)
         query = shape(query, self.dim_per_head)

         if self.max_relative_positions == -1:  # Rotary Embeddings
             start_pos = 0
             seqlen = query.size(2)
-            rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
+            if seqlen > self.rope.size(0):
+                self.rope = rotaryembeddings(
+                    self.dim_per_head, maxseqlen=(seqlen + 2048)
+                )
+            rope = self.rope[start_pos : start_pos + seqlen]
             query, key = apply_rotary_emb(
                 query, key, rope, interleave=self.rotary_interleave
             )
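
The guard added at both call sites turns self.rope into a grow-on-demand cache: the table is rebuilt with 2048 positions of headroom whenever a batch exceeds its current length. In isolation the pattern looks like this (an illustrative standalone helper, not code from the PR; attn stands for the attention layer holding rope and dim_per_head):

def get_rope_slice(attn, start_pos, seqlen):
    """Hypothetical helper mirroring the guard above."""
    if seqlen > attn.rope.size(0):
        # rebuild once with headroom instead of erroring on sequences longer
        # than the table built at construction time
        attn.rope = rotaryembeddings(attn.dim_per_head, maxseqlen=(seqlen + 2048))
    return attn.rope[start_pos : start_pos + seqlen]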
@@ -472,7 +487,6 @@ def forward(
         # Ultimately flashv2 will be part of pytorch https://github.com/pytorch/pytorch/pull/105602
         # In the meantime: if vanilla transformer or Rotary embeddings (not rel_pos, not alibi)
         # then use flash2 if seq len > 256 otherwise use xtransformer from pt2 upstream
-
         flash2 = (
             self.flash2
             and l > 256  # https://github.com/Dao-AILab/flash-attention/issues/591
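
The comment block above summarizes the backend choice; restated as a small function for clarity (a sketch: the helper name is made up, treating max_relative_positions == 0 as the vanilla case is my assumption, while -1 meaning rotary and the 256-token threshold come from the code):

def pick_attention_backend(flash2_available, seqlen, max_relative_positions):
    # flash-attn 2 only for vanilla or rotary position handling and
    # sequences longer than 256 tokens; otherwise PyTorch SDPA
    if flash2_available and max_relative_positions in (0, -1) and seqlen > 256:
        return "flash-attn-2"
    return "torch-sdpa"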