
Commit b9c7180

talumbau authored and copybara-github committed
Build RoPE cos, sin tensors on demand
PiperOrigin-RevId: 712561555
1 parent 5fb930e commit b9c7180

5 files changed: +98 −98 lines changed


ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 25 additions & 43 deletions
```diff
@@ -15,14 +15,13 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -104,12 +103,17 @@ def __init__(self, config: cfg.ModelConfig):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
     attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -136,48 +140,29 @@ def forward(
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
-
-    # token embeddings of shape (b, t, n_embd)
-    input_embeds = self.tok_embedding(tokens)
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-    mask = [self.get_attention_mask(
-        self.config.block_config(i).attn_config.attn_type, input_pos
-    ) for i in range(self.config.num_layers)]
-
-    return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
-    )
-
-  def _forward_with_embeds(
-      self,
-      input_embeds: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
 
-    if self.config.embedding_scale is not None:
-      input_embeds = input_embeds * self.config.embedding_scale
-    x = input_embeds
-    updated_kv_entries = []
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
@@ -243,13 +228,11 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
   )
 
   num_layers = 26
-  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=embedding_dim,
-      embedding_scale=embedding_dim**0.5,
+      embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -266,7 +249,6 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
  config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
  config.block_configs = config.block_configs[: config.num_layers]
  for block_config in config.block_configs:
    block_config.attn_config.num_heads = 4
```
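The change above replaces the per-call `build_rope` computation with a cos/sin table built once in `__init__` and sliced with `index_select` on each forward pass (the same pattern reappears in `model_builder.py` below). A minimal sketch of that idea, using a hypothetical `build_rope_cache_sketch` helper rather than the library's `attn_utils.build_rope_cache`, whose exact layout may differ:

```python
import torch

def build_rope_cache_sketch(size: int, dim: int, base: int = 10_000):
  """Precompute half-width cos/sin rows for positions [0, size)."""
  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
  pos = torch.arange(size).float()                                 # (size,)
  angles = torch.outer(pos, theta)                                 # (size, dim/2)
  return torch.cos(angles), torch.sin(angles)

# Built once (e.g. in __init__), reused by every transformer block.
cos_cache, sin_cache = build_rope_cache_sketch(size=1024, dim=256)

# Per forward call: pick out the rows for the current positions.
input_pos = torch.tensor([5, 6, 7])
cos = cos_cache.index_select(0, input_pos)  # (3, 128)
sin = sin_cache.index_select(0, input_pos)  # (3, 128)
```

The `cos.repeat(1, 2)` calls in `_embed_rope` (see `attention.py` below) suggest the real cache stores half-width cos/sin rows, as this sketch does.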

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -72,14 +72,14 @@ def forward(
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
 
-    updated_kv_entries = []
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
+        updated_kv_entires.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
```

ai_edge_torch/generative/layers/attention.py

Lines changed: 29 additions & 4 deletions
```diff
@@ -26,6 +26,33 @@
 from torch import nn
 
 
+def _embed_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    n_elem: int,
+    rope: Tuple[torch.Tensor, torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Embed rotary positional embedding for query and key.
+
+  Args:
+    q (torch.Tensor): query tensor.
+    k (torch.Tensor): key tensor.
+    n_elem (int): number of elements to embed rotarty positional embedding.
+    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+  """
+  if n_elem > 0:
+    cos, sin = rope
+    q_roped = rotary_pos_emb.apply_rope(
+        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    k_roped = rotary_pos_emb.apply_rope(
+        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+  return q, k
+
+
 class TransformerBlock(nn.Module):
 
   def __init__(
@@ -211,8 +238,7 @@ def forward(
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -348,8 +374,7 @@ def forward(
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
```
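The new `_embed_rope` helper applies RoPE only to the first `n_elem` channels of each head (`n_elem = rotary_percentage * head_dim`) and passes the remaining channels through untouched. The sketch below illustrates that partial-rotary idea with the rotate-half convention; it is a standalone approximation, not the library's implementation, and assumes `cos`/`sin` are already broadcastable against `q[..., :n_elem]`:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
  half = x.shape[-1] // 2
  return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def partial_rope(q, k, cos, sin, n_elem):
  """Rotate only the first n_elem channels; pass the rest through unchanged."""
  if n_elem <= 0:
    return q, k
  q_rot = q[..., :n_elem] * cos + rotate_half(q[..., :n_elem]) * sin
  k_rot = k[..., :n_elem] * cos + rotate_half(k[..., :n_elem]) * sin
  q = torch.cat((q_rot, q[..., n_elem:]), dim=-1)
  k = torch.cat((k_rot, k[..., n_elem:]), dim=-1)
  return q, k

# Toy shapes: (batch, seq, heads, head_dim), rotating half of each head.
B, T, H, D, n_elem = 1, 3, 4, 64, 32
q, k = torch.randn(B, T, H, D), torch.randn(B, T, H, D)
theta = 1.0 / (10_000 ** (torch.arange(0, n_elem, 2).float() / n_elem))
ang = torch.outer(torch.arange(T).float(), theta)   # (T, n_elem/2)
cos = torch.cos(ang).repeat(1, 2).unsqueeze(1)       # (T, 1, n_elem)
sin = torch.sin(ang).repeat(1, 2).unsqueeze(1)
q2, k2 = partial_rope(q, k, cos, sin, n_elem)
assert q2.shape == q.shape and torch.equal(q2[..., n_elem:], q[..., n_elem:])
```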

ai_edge_torch/generative/layers/rotary_position_embedding.py

Lines changed: 27 additions & 34 deletions
```diff
@@ -32,64 +32,57 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1, x2 = torch.split(x, head_size // 2, dim=-1)
-  left = x1 * cos - x2 * sin
-  right = x2 * cos + x1 * sin
-  roped = torch.cat([left, right], dim=-1)
+  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
+  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
+  rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+  roped = (x * cos) + (rotated * sin)
   return roped.transpose(1, 2).type_as(x)
 
 
-def build_rope(
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
     input_pos: torch.Tensor,
     n_elem: int,
-    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding cosine and sine tensors.
+  """Computes rotary positional embedding inline for a query and key.
 
   Args:
+    q: the query tensor.
+    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
-    base: the base of the exponentiated value for RoPE.
 
   Returns:
-    cos, sin tensors
+    output the RoPE'd query and key.
   """
 
   if n_elem <= 0:
-    return None, None
+    return q, k
 
   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      head_dim // 2, dtype=torch.float32
+      q.shape[-1] // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians)
-  sin = torch.sin(radians)
-  return cos, sin
-
+  cos = torch.cos(radians).type_as(q)
+  sin = torch.sin(radians).type_as(q)
 
-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
-
-  Args:
-    q: the query tensor.
-    k: the key tensor.
-    cos: the cosine tensor.
-    sin: the sine tensor.
-
-  Returns:
-    output the RoPE'd query and key.
-  """
+  def apply(x, sin, cos):
+    x = x.transpose(1, 2)
+    b, h, s, d = x.shape
+    ans = torch.split(x, d // 2, dim=-1)
+    x1, x2 = ans
+    left = x1 * cos - x2 * sin
+    right = x2 * cos + x1 * sin
+    res = torch.cat([left, right], dim=-1)
+    res = res.transpose(1, 2)
+    return res
 
-  q_roped = apply_rope(q, cos, sin)
-  k_roped = apply_rope(k, cos, sin)
+  q_roped = apply(q, sin, cos)
+  k_roped = apply(k, sin, cos)
   return q_roped, k_roped
```
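This hunk rewrites `apply_rope` from the split-halves form, `cat([x1*cos - x2*sin, x2*cos + x1*sin])` with half-width cos/sin, to the rotate-half form, `x*cos + cat([-x2, x1])*sin` with cos/sin repeated across both halves (hence the `cos.repeat(1, 2)` in `_embed_rope`). The two formulations are algebraically identical; a quick standalone numerical check, independent of the library code:

```python
import torch

torch.manual_seed(0)
B, H, T, D = 1, 2, 3, 8
x = torch.randn(B, H, T, D)
theta = 1.0 / (10_000 ** (torch.arange(0, D, 2).float() / D))
ang = torch.outer(torch.arange(T).float(), theta)  # (T, D/2)
cos, sin = torch.cos(ang), torch.sin(ang)          # half-width

# Split-halves formulation: rotate pairs explicitly, then concatenate.
x1, x2 = torch.split(x, D // 2, dim=-1)
old = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

# Rotate-half formulation: full-width cos/sin (halves repeated) plus rotate-half.
cos_full, sin_full = cos.repeat(1, 2), sin.repeat(1, 2)
new = x * cos_full + torch.cat((-x2, x1), dim=-1) * sin_full

assert torch.allclose(old, new, atol=1e-6)
```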

ai_edge_torch/generative/utilities/model_builder.py

Lines changed: 14 additions & 14 deletions
```diff
@@ -24,7 +24,6 @@
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -86,6 +85,13 @@ def __init__(self, config: cfg.ModelConfig):
         config.embedding_dim,
         config.final_norm_config,
     )
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -107,22 +113,16 @@ def forward(
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-
+    cos, sin = self.rope_cache
+    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    return self._forward_with_embeds(
+    return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
 
-  def _forward_with_embeds(
+  def forward_with_embeds(
      self,
      input_embeds: torch.Tensor,
      rope: Tuple[torch.Tensor, torch.Tensor],
@@ -141,13 +141,13 @@ def _forward_with_embeds(
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-    updated_kv_entries = []
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
```
