
Commit dc45276

talumbau authored and copybara-github committed
Build RoPE cos, sin tensors on demand
PiperOrigin-RevId: 707748616
1 parent 6abeb94 commit dc45276
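
In short, models previously precomputed a full-length RoPE cos/sin table in __init__ (attn_utils.build_rope_cache) and sliced it with index_select in forward(); after this change, forward() calls rotary_pos_emb.build_rope(input_pos, ...) and computes cos/sin only for the requested positions. A minimal standalone sketch of the two patterns (plain PyTorch, simplified shapes, helper names of my own; not the library code itself):

import torch

def cached_rope(max_len: int, dim: int, base: float = 10_000.0):
  # Old pattern: precompute cos/sin for every position up front, slice later.
  pos = torch.arange(max_len, dtype=torch.float32)
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
  angles = pos[:, None] * inv_freq[None, :]  # (max_len, dim // 2)
  return torch.cos(angles), torch.sin(angles)

def on_demand_rope(input_pos: torch.Tensor, dim: int, base: float = 10_000.0):
  # New pattern: compute cos/sin only for the positions actually requested.
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
  angles = input_pos.float()[:, None] * inv_freq[None, :]  # (T, dim // 2)
  return torch.cos(angles), torch.sin(angles)

input_pos = torch.tensor([5, 6, 7])
cos_cache, sin_cache = cached_rope(max_len=16, dim=8)
cos_new, sin_new = on_demand_rope(input_pos, dim=8)
assert torch.allclose(cos_cache.index_select(0, input_pos), cos_new)
assert torch.allclose(sin_cache.index_select(0, input_pos), sin_new)

The on-demand form avoids materializing a kv_cache_max-by-dim table per model, at the cost of a small amount of per-step math.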

File tree

5 files changed: 66 additions, 86 deletions


ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 14 additions & 15 deletions
@@ -22,6 +22,7 @@
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ def __init__(self, config: cfg.ModelConfig):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -145,24 +141,27 @@ def forward(
         " must be the same."
     )

-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )

     # token embeddings of shape (b, t, n_embd)
     x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)

-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       mask = self.get_attention_mask(
           block.config.attn_config.attn_type, input_pos
       )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

Lines changed: 3 additions & 3 deletions
@@ -72,14 +72,14 @@ def forward(
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]

-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)

-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (

ai_edge_torch/generative/layers/attention.py

Lines changed: 4 additions & 29 deletions
@@ -26,33 +26,6 @@
 from torch import nn


-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotarty positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):

   def __init__(
@@ -238,7 +211,8 @@ def forward(
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -374,7 +348,8 @@ def forward(
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
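
For reference, the removed _embed_rope helper unpacked the (cos, sin) pair inside this file before applying RoPE; the two forward() call sites above now unpack the pair themselves and pass q and k directly to rotary_pos_emb.apply_rope_inline. A hedged sketch of the new call-site shape (rope_q_k is an illustrative name, not part of the library; assumes the ai_edge_torch package is importable):

import torch
from ai_edge_torch.generative.layers import rotary_position_embedding as rotary_pos_emb

def rope_q_k(q: torch.Tensor, k: torch.Tensor, rope):
  # Mirrors the updated call sites: rope is either None or a (cos, sin) pair
  # produced by rotary_pos_emb.build_rope for the current input positions.
  if rope is not None:
    cos, sin = rope
    q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
  return q, k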

ai_edge_torch/generative/layers/rotary_position_embedding.py

Lines changed: 34 additions & 27 deletions
@@ -32,57 +32,64 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
-  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
-  rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
-  roped = (x * cos) + (rotated * sin)
+  x1, x2 = torch.split(x, head_size // 2, dim=-1)
+  left = x1 * cos - x2 * sin
+  right = x2 * cos + x1 * sin
+  roped = torch.cat([left, right], dim=-1)
   return roped.transpose(1, 2).type_as(x)


-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
+def build_rope(
     input_pos: torch.Tensor,
     n_elem: int,
+    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
+  """Computes rotary positional embedding cosine and sine tensors.

   Args:
-    q: the query tensor.
-    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
+    base: the base of the exponentiated value for RoPE.

   Returns:
-    output the RoPE'd query and key.
+    cos, sin tensors
   """

   if n_elem <= 0:
-    return q, k
+    return None, None

   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      q.shape[-1] // 2, dtype=torch.float32
+      head_dim // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians).type_as(q)
-  sin = torch.sin(radians).type_as(q)
+  cos = torch.cos(radians)
+  sin = torch.sin(radians)
+  return cos, sin
+

-def apply(x, sin, cos):
-  x = x.transpose(1, 2)
-  b, h, s, d = x.shape
-  ans = torch.split(x, d // 2, dim=-1)
-  x1, x2 = ans
-  left = x1 * cos - x2 * sin
-  right = x2 * cos + x1 * sin
-  res = torch.cat([left, right], dim=-1)
-  res = res.transpose(1, 2)
-  return res
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Computes rotary positional embedding inline for a query and key.
+
+  Args:
+    q: the query tensor.
+    k: the key tensor.
+    cos: the cosine tensor.
+    sin: the sine tensor.
+
+  Returns:
+    output the RoPE'd query and key.
+  """

-  q_roped = apply(q, sin, cos)
-  k_roped = apply(k, sin, cos)
+  q_roped = apply_rope(q, cos, sin)
+  k_roped = apply_rope(k, cos, sin)
   return q_roped, k_roped
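
A quick usage sketch of the reshaped API, assuming the ai_edge_torch package is installed and that q and k follow the (batch, seq_len, n_heads, head_dim) layout implied by apply_rope's transpose(1, 2); the concrete sizes below are arbitrary:

import torch
from ai_edge_torch.generative.layers import rotary_position_embedding as rotary_pos_emb

head_dim = 8
input_pos = torch.arange(4)  # decode positions 0..3
# n_elem == head_dim corresponds to rotary_percentage == 1.0 in the callers above.
cos, sin = rotary_pos_emb.build_rope(input_pos, head_dim, head_dim, base=10_000)

q = torch.randn(1, 4, 2, head_dim)  # (B, T, n_heads, head_dim)
k = torch.randn(1, 4, 2, head_dim)
q_roped, k_roped = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
print(q_roped.shape, k_roped.shape)  # torch.Size([1, 4, 2, 8]) torch.Size([1, 4, 2, 8])

Unlike the old apply_rope_inline, which took input_pos and recomputed the angles itself, the new split keeps angle construction (build_rope) in the model's forward() and leaves apply_rope_inline as a pure tensor transform on q and k.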

ai_edge_torch/generative/utilities/model_builder.py

Lines changed: 11 additions & 12 deletions
@@ -24,6 +24,7 @@
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -85,13 +86,6 @@ def __init__(self, config: cfg.ModelConfig):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -113,11 +107,16 @@ def forward(

     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]

+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
@@ -141,13 +140,13 @@ def forward_with_embeds(
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale

-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (
