Commit c8d0b19

ai-edge-bot authored and copybara-github committed
Fix broken Phi2 and PaliGemma1
- Phi2 applies RoPE only partially (rotary_percentage=0.4): the first 40% of each query/key head's dimensions must be roped while the rest must be left untouched.
- The PaliGemma1 decoder expects _forward_with_embeds.
- Verified that every example's verify.py now passes.

PiperOrigin-RevId: 715167410
1 parent 58f1c71 commit c8d0b19
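For readers skimming the diff, here is a minimal standalone sketch of the partial-RoPE behavior this commit restores. It mirrors the new apply_rope/build_rope logic shown in rotary_position_embedding.py below, but it is not the library code, and the head_dim, rotary_percentage, and position values are illustrative assumptions.

import torch

def build_rope_sketch(input_pos: torch.Tensor, n_elem: int, base: int = 10_000):
  # cos/sin cover only the n_elem rotated dimensions (n_elem // 2 frequencies).
  freq_exponents = (2.0 / n_elem) * torch.arange(n_elem // 2, dtype=torch.float32)
  timescale = float(base) ** freq_exponents
  radians = input_pos.float()[:, None] / timescale[None, :]
  return torch.cos(radians), torch.sin(radians)

def apply_partial_rope_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
  # x: (seq_len, head_dim). Split into chunks of cos.size(-1); rotate only the
  # first two chunks as the (real, imaginary) halves, pass the rest through.
  rope_size = cos.size(-1)
  chunks = torch.split(x, rope_size, dim=-1)
  left = chunks[0] * cos - chunks[1] * sin
  right = chunks[1] * cos + chunks[0] * sin
  return torch.cat((left, right) + chunks[2:], dim=-1)

# Demo values are assumptions: head_dim=80, rotary_percentage=0.4 -> n_elem=32.
head_dim, rotary_percentage = 80, 0.4
n_elem = int(rotary_percentage * head_dim)            # 32 dims get rotated
cos, sin = build_rope_sketch(torch.arange(6), n_elem)
x = torch.randn(6, head_dim)
roped = apply_partial_rope_sketch(x, cos, sin)
assert torch.equal(roped[:, n_elem:], x[:, n_elem:])  # last 60% untouched

The key point is that the cos/sin tables are sized by n_elem alone, and the split in apply_rope leaves every chunk beyond the first two untouched.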

7 files changed (+14, -29 lines)

ai_edge_torch/generative/examples/gemma/gemma2.py
Lines changed: 1 addition & 3 deletions

@@ -143,9 +143,7 @@ def forward(
     # RoPE parameters are the same for all blocks. Use the first layer.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
+    rope = rotary_pos_emb.build_rope(input_pos, n_elem, attn_config.rotary_base)
     mask = [
         self.get_attention_mask(
             self.config.block_config(i).attn_config.attn_type, input_pos

ai_edge_torch/generative/examples/llama/llama.py
Lines changed: 0 additions & 1 deletion

@@ -37,7 +37,6 @@ def _build_llama3_rope_cache(
     low_freq_factor: float,
     high_freq_factor: float,
     max_seq_len: int,
-    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
   """Computes Rotary Positional Embeddings for Llama 3.2 model.

ai_edge_torch/generative/examples/paligemma/decoder.py
Lines changed: 1 addition & 3 deletions

@@ -67,9 +67,7 @@ def forward(
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
+    rope = rotary_pos_emb.build_rope(repo_pos, n_elem, attn_config.rotary_base)

     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.

ai_edge_torch/generative/examples/paligemma/decoder2.py
Lines changed: 1 addition & 3 deletions

@@ -70,9 +70,7 @@ def forward(
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
+    rope = rotary_pos_emb.build_rope(repo_pos, n_elem, attn_config.rotary_base)

     if mask is None:
       if called_by_generate:

ai_edge_torch/generative/examples/phi/phi3.py
Lines changed: 1 addition & 2 deletions

@@ -103,7 +103,6 @@ def _build_phi3_rope(
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
-    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
   """Computes Rotary Positional Embeddings for Phi-3.5 model.

@@ -173,6 +172,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+
   max_seq_len = 4096
   # Create the RoPE callable
   build_rope = partial(

@@ -182,7 +182,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       device=torch.device("cpu"),
       theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
       scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
-      max_seq_len=max_seq_len,
   )

   config = cfg.ModelConfig(
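For context on why the extra keyword arguments (**kwargs and the bound max_seq_len) could be dropped: the model config stores a RoPE builder created with functools.partial, and the generic model builder (see model_builder.py below) now calls it with exactly (input_pos, n_elem, base). A minimal standalone sketch of that binding pattern follows; _toy_rope_builder and its scale knob are made up for illustration and are not the Phi-3.5 implementation.

from functools import partial
from typing import Tuple

import torch

def _toy_rope_builder(
    input_pos: torch.Tensor,
    n_elem: int,
    base: int = 10_000,
    scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
  # Toy builder with the same positional interface the generic model builder
  # uses, plus one model-specific keyword that gets bound via partial below.
  freq_exponents = (2.0 / n_elem) * torch.arange(n_elem // 2, dtype=torch.float32)
  timescale = float(base) ** freq_exponents
  radians = input_pos.float()[:, None] / timescale[None, :]
  return torch.cos(radians) * scale, torch.sin(radians) * scale

# Model-specific knobs are bound once when the config is built...
build_rope = partial(_toy_rope_builder, scale=1.2)

# ...so callers only ever pass (input_pos, n_elem, base), matching the new
# rotary_pos_emb.build_rope signature.
cos, sin = build_rope(torch.arange(8), 32, 10_000)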

ai_edge_torch/generative/layers/rotary_position_embedding.py
Lines changed: 7 additions & 7 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 # Implementation for Rotary Position embedding. https://arxiv.org/pdf/2104.09864.pdf
+
 from typing import Tuple
 import torch

@@ -31,18 +32,17 @@ def apply_rope(
     output tensor of RoPE.
   """
   x = x.transpose(1, 2)
-  head_size = x.size(-1)
-  x1, x2 = torch.split(x, head_size // 2, dim=-1)
-  left = x1 * cos - x2 * sin
-  right = x2 * cos + x1 * sin
-  roped = torch.cat([left, right], dim=-1)
+  rope_size = cos.size(-1)
+  x_splited = torch.split(x, rope_size, dim=-1)
+  left = x_splited[0] * cos - x_splited[1] * sin
+  right = x_splited[1] * cos + x_splited[0] * sin
+  roped = torch.cat((left, right) + x_splited[2:], dim=-1)
   return roped.transpose(1, 2).type_as(x)


 def build_rope(
     input_pos: torch.Tensor,
     n_elem: int,
-    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
   """Computes rotary positional embedding cosine and sine tensors.

@@ -60,7 +60,7 @@ def build_rope(
     return None, None

   freq_exponents = (2.0 / n_elem) * torch.arange(
-      head_dim // 2, dtype=torch.float32
+      n_elem // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
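A hedged usage sketch of the updated build_rope signature (assumes an installed ai_edge_torch; the head_dim and rotary_percentage values are illustrative, not taken from the diff):

import torch
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

# Illustrative values; in the models above they come from attn_config.
head_dim = 80
rotary_percentage = 0.4   # Phi2-style partial RoPE
rotary_base = 10_000
input_pos = torch.arange(16)

# Old call: build_rope(input_pos, n_elem, head_dim, rotary_base).
# New call: head_dim is gone; cos/sin are sized from n_elem alone, and
# apply_rope infers the rotated width from cos.size(-1) at call time.
n_elem = int(rotary_percentage * head_dim)
cos, sin = rotary_pos_emb.build_rope(input_pos, n_elem, rotary_base)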

ai_edge_torch/generative/utilities/model_builder.py
Lines changed: 3 additions & 10 deletions

@@ -25,7 +25,6 @@
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn

@@ -115,23 +114,17 @@ def forward(
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = self.config.build_rope(
-        input_pos=input_pos,
-        n_elem=n_elem,
-        base=attn_config.rotary_base,
-        head_dim=attn_config.head_dim,
-        # input_pos=input_pos, n_elem=n_elem, base=attn_config.rotary_base
-    )
+    rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)

     if mask is None:
       mask = self.mask_cache.index_select(2, input_pos)
       mask = mask[:, :, :, : self.config.kv_cache_max]

-    return self.forward_with_embeds(
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, lora, export_config
     )

-  def forward_with_embeds(
+  def _forward_with_embeds(
       self,
       input_embeds: torch.Tensor,
       rope: Tuple[torch.Tensor, torch.Tensor],