Commit 85446ef

talumbau authored and copybara-github committed
Build RoPE Inline with configurable function
PiperOrigin-RevId: 713321103
1 parent 17a87a5 commit 85446ef
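
What the change does, in brief: each model previously precomputed a RoPE cos/sin cache at construction time and indexed it with input_pos on every forward call; after this commit, RoPE is computed inline for just the given positions by a function that each model config can swap out. A minimal sketch of the idea, using a hypothetical helper name (the repository's actual entry point is ai_edge_torch.generative.layers.rotary_position_embedding.build_rope, called in the gemma2.py diff below as build_rope(input_pos, n_elem, head_dim, rotary_base)):

import torch

def build_rope_inline(
    input_pos: torch.Tensor,  # positions to embed, shape (seq_len,)
    n_elem: int,              # number of rotary dimensions
    base: int = 10_000,       # RoPE frequency base
) -> tuple[torch.Tensor, torch.Tensor]:
  """Computes cos/sin only for the requested positions (no cached table)."""
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  idx_theta = torch.outer(input_pos.float(), theta)
  return torch.cos(idx_theta), torch.sin(idx_theta)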

File tree

8 files changed: +164, -147 lines changed


ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 46 additions & 25 deletions
@@ -15,13 +15,14 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ def __init__(self, config: cfg.ModelConfig):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -140,29 +136,51 @@ def forward(
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [
+        self.get_attention_mask(
+            self.config.block_config(i).attn_config.attn_type, input_pos
+        )
+        for i in range(self.config.num_layers)
+    ]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
    )
 
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
@@ -228,11 +246,13 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
   )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +269,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
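
Note how the Gemma2 forward path is now split: forward() builds the token embeddings, the inline RoPE pair, and one attention mask per layer, then delegates to _forward_with_embeds(), while the sqrt(embedding_dim) scaling moves into config.embedding_scale. A hedged usage sketch of the public path, assuming KVCache.from_model_config exists in this library's kv_cache utilities and that forward accepts (tokens, input_pos, kv_cache) positionally; adjust to the real constructors if they differ:

import torch
from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = gemma2.get_fake_model_config(kv_cache_max_len=128)
model = gemma2.Gemma2(config)

tokens = torch.zeros((1, 8), dtype=torch.int)     # (batch, seq_len) token ids
input_pos = torch.arange(0, 8, dtype=torch.int)   # positions for this call
kv = kv_utils.KVCache.from_model_config(config)   # assumed constructor

# forward() computes RoPE inline from input_pos and a per-layer mask list,
# then hands everything to _forward_with_embeds().
output = model.forward(tokens, input_pos, kv)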

ai_edge_torch/generative/examples/llama/llama.py

Lines changed: 29 additions & 25 deletions
@@ -15,6 +15,7 @@
 
 """Example of building Llama 3.2 models."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -26,8 +27,8 @@
 
 
 def _build_llama3_rope_cache(
-    size: int,
-    dim: int,
+    input_pos: torch.Tensor,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
     low_freq_factor: float,
     high_freq_factor: float,
     max_seq_len: int,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Llama 3.2 model.
+  """Computes Rotary Positional Embeddings for Llama 3.2 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
   https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value.
-    condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed.
-    dtype (torch.dtype, optional): Output tensor's data type.
-    device (torch.device, optional): Output tensor's data type.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
+    base (int): Rope base value.
+    condense_ratio (int): The ratio by which sequence indicies are condensed.
+    dtype (torch.dtype): Output tensor's data type.
+    device (torch.device): Output tensor's data type.
     factor (float): Factor to scale theta down for tokens in long range in the
       sequence.
     low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@ def _build_llama3_rope_cache(
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   low_freq_wavelen = max_seq_len / low_freq_factor
   high_freq_wavelen = max_seq_len / high_freq_factor
   wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@ def _build_llama3_rope_cache(
   is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
   theta = torch.where(is_medium, smoothed_theta, theta)
 
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_llama3_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        factor=32.0,
-        low_freq_factor=1.0,
-        high_freq_factor=4.0,
-        max_seq_len=self.config.max_seq_len,
-    )
 
 
 def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,15 +129,30 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+
+  max_seq_len = 8192
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_llama3_rope_cache,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      factor=32.0,
+      low_freq_factor=1.0,
+      high_freq_factor=4.0,
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=128256,
      num_layers=16,
-      max_seq_len=8192,
+      max_seq_len=max_seq_len,
      embedding_dim=2048,
      kv_cache_max_len=kv_cache_max_len,
      block_configs=block_config,
      final_norm_config=norm_config,
      enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 
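
The llama.py change also shows the configuration pattern: model-specific RoPE parameters are bound up front with functools.partial, and the new **kwargs parameter lets the builder silently accept any extra keyword arguments a generic caller supplies. A small self-contained sketch of that pattern (generic names, not the library's code):

from functools import partial
import torch

def scaled_rope(input_pos, n_elem, base, scale, **kwargs):
  """Toy RoPE builder: extra keyword args (e.g. head_dim) are ignored."""
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  idx_theta = torch.outer(input_pos.float(), theta)
  return torch.cos(idx_theta) * scale, torch.sin(idx_theta) * scale

# Bind the model-specific knob once; callers only supply generic arguments.
build_rope = partial(scaled_rope, scale=1.2)
cos, sin = build_rope(torch.arange(4), n_elem=8, base=10_000, head_dim=16)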

ai_edge_torch/generative/examples/phi/phi3.py

Lines changed: 26 additions & 23 deletions
@@ -15,6 +15,7 @@
 
 """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -93,40 +94,41 @@
 ]
 
 
-def _build_rope_cache(
-    size: int,
-    dim: int,
+def _build_phi3_rope(
+    input_pos: int,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
+  """Computes Rotary Positional Embeddings for Phi-3.5 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
   Cos values with scaling factors for quick lookup during the inference.
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
     base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
       condensed.
     dtype (torch.dtype, optional): Output tensor's data type.
     device (torch.device, optional): Output tensor's data type.
-    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
     scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   theta = theta / theta_factors
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
-        scale=math.sqrt(
-            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
-        ),
-    )
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
      post_attention_norm_config=norm_config,
   )
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi3_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
      vocab_size=32064,
      num_layers=32,
-      max_seq_len=4096,
+      max_seq_len=max_seq_len,
      kv_cache_max_len=kv_cache_max_len,
      embedding_dim=3072,
      block_configs=block_config,
      final_norm_config=norm_config,
      lm_head_share_weight_with_embedding=False,
      enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
 

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py

Lines changed: 3 additions & 3 deletions
@@ -72,14 +72,14 @@ def forward(
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
 
-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
