
Commit da2fe8f

Fixrotary (#2511)
* fix rotary embed with interleave false
1 parent 212141b commit da2fe8f
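
The commit message above refers to the non-interleaved (GPT-NeoX style) rotary convention: the precomputed table is duplicated across the two halves of the head dimension (the torch.cat added in rotaryembeddings) and the rotation is applied with cos/sin tables and a rotate_half helper instead of complex multiplication. A minimal standalone sketch of that convention on a toy vector follows; the sizes and variable names are illustrative only and are not part of the commit.

import torch


def rotate_half(x):
    """Rotate half the hidden dims: (x1, x2) -> (-x2, x1), as in the patch."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


dim, pos = 4, 3  # toy head dimension and a single token position
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
angles = pos * inv_freq  # [dim / 2] rotation angles for this position
cos = torch.cat((angles.cos(), angles.cos()))  # duplicated halves, like the added torch.cat
sin = torch.cat((angles.sin(), angles.sin()))

x = torch.randn(dim)
x_rot = x * cos + rotate_half(x) * sin

# Each pair (x[i], x[i + dim/2]) is rotated by angles[i]:
i = 0
expected = torch.stack(
    (
        x[i] * angles[i].cos() - x[i + dim // 2] * angles[i].sin(),
        x[i + dim // 2] * angles[i].cos() + x[i] * angles[i].sin(),
    )
)
print(torch.allclose(x_rot[[i, i + dim // 2]], expected))  # True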

File tree

1 file changed: +38 -24 lines

onmt/modules/multi_headed_attn.py

Lines changed: 38 additions & 24 deletions
@@ -19,36 +19,42 @@
 # are both < 2048 tokens.
 
 
-def rotaryembeddings(dim: int, maxseqlen=8192, base=10000):
+def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
     tmax = torch.arange(maxseqlen, device=inv_freq.device)
     rope = torch.outer(tmax, inv_freq).float()
     # rope is now matrix [maxseqlen, dim/2]
     rope = torch.polar(torch.ones_like(rope), rope)
+    rope = torch.cat((rope, rope), dim=1)
     return rope
 
 
-def apply_rotary_emb(query, key, rope, interleave=True):
-    query = query.transpose(1, 2)
-    key = key.transpose(1, 2)
-    if not interleave:
-        query = torch.cat(
-            (-query[..., query.shape[-1] // 2 :], query[..., : query.shape[-1] // 2]),
-            dim=-1,
-        )
-        key = torch.cat(
-            (-key[..., key.shape[-1] // 2 :], key[..., : key.shape[-1] // 2]), dim=-1
-        )
-    query_ = query.float().reshape(*query.shape[:-1], -1, 2)
-    query_ = torch.view_as_complex(query_)
-    key_ = key.float().reshape(*key.shape[:-1], -1, 2)
-    key_ = torch.view_as_complex(key_)
-    rope = rope.view(1, query_.size(1), 1, query_.size(3))
-    query_out = torch.view_as_real(query_ * rope).flatten(3)
-    key_out = torch.view_as_real(key_ * rope).flatten(3)
-    return query_out.transpose(1, 2).type_as(query), key_out.transpose(1, 2).type_as(
-        key
-    )
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_emb(query, key, rope, interleave):
+    if interleave:
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        query_ = query.float().reshape(*query.shape[:-1], -1, 2)
+        query_ = torch.view_as_complex(query_)
+        key_ = key.float().reshape(*key.shape[:-1], -1, 2)
+        key_ = torch.view_as_complex(key_)
+        rope = rope[:, : rope.size(1) // 2].view(1, query_.size(1), 1, query_.size(3))
+        query_out = torch.view_as_real(query_ * rope).flatten(3)
+        key_out = torch.view_as_real(key_ * rope).flatten(3)
+        return query_out.transpose(1, 2).type_as(query), key_out.transpose(
+            1, 2
+        ).type_as(key)
+    else:
+        cos, sin = rope.real, rope.imag
+        q_embed = (query * cos) + (rotate_half(query) * sin)
+        k_embed = (key * cos) + (rotate_half(key) * sin)
+        return q_embed.type_as(query), k_embed.type_as(key)
 
 
 # Help functions for max_relative positions
@@ -412,6 +418,10 @@ def forward(
                 if self.max_relative_positions == -1:  # Rotary Embeddings
                     start_pos = step
                     seqlen = query.size(2)
+                    if seqlen > self.rope.size(0):
+                        self.rope = rotaryembeddings(
+                            self.dim_per_head, maxseqlen=(seqlen + 2048)
+                        )
                     rope = self.rope[start_pos : start_pos + seqlen]
                     query, key = apply_rotary_emb(
                         query, key, rope, interleave=self.rotary_interleave
@@ -444,14 +454,19 @@ def forward(
             key = self.maybe_ckpt(self.linear_keys, key)
             value = self.maybe_ckpt(self.linear_values, value)
             query = self.maybe_ckpt(self.linear_query, query)
+
             key = shape(key, self.dim_per_head)
             value = shape(value, self.dim_per_head)
             query = shape(query, self.dim_per_head)
 
             if self.max_relative_positions == -1:  # Rotary Embeddings
                 start_pos = 0
                 seqlen = query.size(2)
-                rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
+                if seqlen > self.rope.size(0):
+                    self.rope = rotaryembeddings(
+                        self.dim_per_head, maxseqlen=(seqlen + 2048)
+                    )
+                rope = self.rope[start_pos : start_pos + seqlen]
                 query, key = apply_rotary_emb(
                     query, key, rope, interleave=self.rotary_interleave
                 )
@@ -472,7 +487,6 @@ def forward(
             # Ultimately flashv2 will be part of pytorch https://github.com/pytorch/pytorch/pull/105602
             # In the meantime: if vanilla tranformer or Rotary embeddings (not rel_pos, not alibi)
             # then use flash2 if seq len > 256 otherwise use xtransformer from pt2 uptream
-
             flash2 = (
                 self.flash2
                 and l > 256  # https://github.com/Dao-AILab/flash-attention/issues/591
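
A rough usage sketch of the patched helpers follows (not part of the commit): the batch, head, and length values are made up, and the import path matches the file touched by this diff. It mirrors the forward() logic added above, regenerating the cached table when a sequence outgrows it and then slicing the positions actually needed.

import torch

from onmt.modules.multi_headed_attn import apply_rotary_emb, rotaryembeddings

batch, heads, seqlen, dim_per_head = 2, 8, 3000, 64

# Cached table built with the new default maxseqlen=2048; after the added
# torch.cat it is a complex tensor of shape [maxseqlen, dim_per_head].
rope = rotaryembeddings(dim_per_head)

# Regenerate when the sequence is longer than the cached range, as in forward().
if seqlen > rope.size(0):
    rope = rotaryembeddings(dim_per_head, maxseqlen=(seqlen + 2048))
rope = rope[:seqlen]

query = torch.randn(batch, heads, seqlen, dim_per_head)
key = torch.randn(batch, heads, seqlen, dim_per_head)

# Non-interleaved (GPT-NeoX style) path fixed by this commit ...
q_neox, k_neox = apply_rotary_emb(query, key, rope, interleave=False)
# ... and the interleaved (complex multiplication) path; both keep the
# [batch, heads, seqlen, dim_per_head] layout.
q_int, k_int = apply_rotary_emb(query, key, rope, interleave=True)
print(q_neox.shape, q_int.shape)  # torch.Size([2, 8, 3000, 64]) twice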
