
Commit eafcb12

ai-edge-bot authored and copybara-github committed
Fix rope calculation of PaliGemma decoder when pixel_embeds is passed.
PiperOrigin-RevId: 708404348
1 parent 9d387ec commit eafcb12

File tree: 2 files changed (+10 lines, -4 lines)

ai_edge_torch/generative/examples/paligemma/decoder.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -19,6 +19,7 @@
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -61,8 +62,12 @@ def forward(
     assert input_embeds is not None
 
     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
```
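
Before this change the decoder read (cos, sin) out of a precomputed self.rope_cache; the fix builds RoPE on the fly from the block's attention config and the 1-based positions. Below is a minimal sketch of what such an on-the-fly (cos, sin) builder can look like, assuming the usual inverse-frequency formulation; the function name, defaults, and return convention are illustrative assumptions, not the actual rotary_pos_emb.build_rope implementation (which also takes head_dim).

```python
# A minimal sketch, not the ai_edge_torch implementation. Only the rotated
# channel count n_elem and the base are used here, which is enough to show
# the idea of computing RoPE directly from the requested positions.
import torch


def build_rope_sketch(
    input_pos: torch.Tensor,  # (seq_len,) absolute positions, possibly 1-based
    n_elem: int,  # rotated channels = rotary_percentage * head_dim
    base: float = 10_000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
  """Returns (cos, sin) computed directly from the given positions."""
  # One inverse frequency per pair of rotated channels.
  inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  # Angle for every (position, frequency) pair.
  angles = torch.outer(input_pos.float(), inv_freq)  # (seq_len, n_elem // 2)
  return torch.cos(angles), torch.sin(angles)


# Usage mirroring the decoder change: PaliGemma text positions are 1-based,
# so they are shifted before RoPE is built, and nothing is read from a cache
# whose layout assumed text-only inputs.
input_pos = torch.arange(0, 8)  # positions of the current tokens
repo_pos = input_pos + 1        # 1-based, as in the diff above
cos, sin = build_rope_sketch(repo_pos, n_elem=256)
```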

ai_edge_torch/generative/utilities/model_builder.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -107,8 +107,6 @@ def forward(
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
@@ -117,6 +115,9 @@ def forward(
         input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
 
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
```
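
The model_builder change only moves the causal-mask lookup so it happens after RoPE is built; the slicing logic itself is unchanged. A small, self-contained illustration of that mask-slicing pattern follows; the shapes and the kv_cache_max value are made-up example values, not taken from the library.

```python
# Hedged sketch of the mask_cache.index_select(...) pattern used above:
# pick the query rows for the current positions, then keep only the first
# kv_cache_max key columns.
import torch

kv_cache_max = 16

# Full causal mask of shape (1, 1, kv_cache_max, kv_cache_max):
# 0 where attention is allowed, -inf strictly above the diagonal.
mask_cache = torch.full((1, 1, kv_cache_max, kv_cache_max), float("-inf"))
mask_cache = torch.triu(mask_cache, diagonal=1)

# Positions of the tokens currently being decoded.
input_pos = torch.tensor([3, 4, 5])
mask = mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, :kv_cache_max]
print(mask.shape)  # torch.Size([1, 1, 3, 16])
```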
