
Commit 21d2732

ai-edge-bot authored and copybara-github committed
Pass mask and export config correctly in PaliGemma's decoder.
PiperOrigin-RevId: 722781201
1 parent b43a14e commit 21d2732

3 files changed: +16 −10 lines changed

3 files changed

+16
-10
lines changed
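In short, the decoders' forward() methods accepted `mask` and `export_config` but dropped them when the text-only path delegated to the parent class. A minimal sketch of the fix, using simplified stand-in names rather than the library classes:

```python
class Parent:
  def forward(self, tokens, input_pos, kv_cache, mask=None, export_config=None):
    ...


class Decoder(Parent):
  def forward(
      self,
      tokens,
      input_pos,
      kv_cache,
      mask=None,
      export_config=None,
      input_embeds=None,
  ):
    if input_embeds is None:
      # Before this commit: super().forward(tokens, input_pos, kv_cache),
      # which silently discarded the caller's mask and export_config.
      return super().forward(
          tokens, input_pos, kv_cache, mask, export_config=export_config
      )
```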

ai_edge_torch/generative/examples/paligemma/decoder.py

Lines changed: 11 additions & 5 deletions
@@ -19,7 +19,6 @@

 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -59,25 +58,32 @@ def forward(
       called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
-      return super().forward(tokens, input_pos, kv_cache)
+      return super().forward(
+          tokens, input_pos, kv_cache, mask, export_config=export_config
+      )

     assert input_embeds is not None

     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(repo_pos, n_elem, attn_config.rotary_base)
+    rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)

     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
-    embeds_len = input_embeds.shape[1]
     if mask is None:
+      embeds_len = input_embeds.shape[1]
       mask = torch.zeros(embeds_len, self.config.kv_cache_max)
       mask[:, embeds_len:] = float("-inf")

     return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache
+        input_embeds,
+        rope,
+        mask,
+        input_pos,
+        kv_cache,
+        export_config=export_config,
     )

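When no mask is supplied, this decoder builds a prefix mask: the image (and prompt) embeddings form a fully visible block, and only KV-cache slots beyond that prefix are blocked, which is why a diagonal causal mask "doesn't work here". A minimal standalone sketch of that construction, with made-up sizes:

```python
import torch

embeds_len = 4     # made-up number of prefix (image + prompt) embeddings
kv_cache_max = 8   # made-up KV-cache capacity

# Rows are query positions over the prefix, columns are KV-cache slots.
# No prefix position is hidden from any other; only slots past the prefix
# are blocked with -inf.
mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = float("-inf")

print(mask[0])
# tensor([0., 0., 0., 0., -inf, -inf, -inf, -inf])
```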

ai_edge_torch/generative/examples/paligemma/decoder2.py

Lines changed: 4 additions & 4 deletions
@@ -20,7 +20,6 @@
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -62,19 +61,20 @@ def forward(
       called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
-      return super().forward(tokens, input_pos, kv_cache)
+      return super().forward(tokens, input_pos, kv_cache, mask, export_config)

     assert input_embeds is not None

     repo_pos = input_pos + 1  # PaliGemma2 position is 1-based.
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(repo_pos, n_elem, attn_config.rotary_base)
+    rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)

     if mask is None:
       if called_by_generate:
-        # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
+        # PaliGemma2 generate() uses a diagonal causal mask even with image
+        # embeds.
         mask = [
             self.get_attention_mask(
                 self.config.block_config(i).attn_config.attn_type, input_pos
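In the generate() path this decoder builds one mask per transformer block from each block's attention type (the Gemma 2 backbone mixes attention types across blocks, e.g. global vs. sliding-window). The sketch below is an independent illustration of that per-block mask list with made-up sizes and window width; it is not the library's get_attention_mask implementation:

```python
import torch

kv_cache_max = 8
input_pos = 3     # made-up current decode position
num_blocks = 2
window = 4        # assumed sliding-window width, for illustration only


def causal_row(pos: int) -> torch.Tensor:
  # Attend to positions <= pos; block the future with -inf.
  row = torch.full((kv_cache_max,), float("-inf"))
  row[: pos + 1] = 0.0
  return row


def sliding_window_row(pos: int, window: int) -> torch.Tensor:
  # Causal, but additionally block positions older than the window.
  row = causal_row(pos)
  row[: max(0, pos + 1 - window)] = float("-inf")
  return row


# One mask per block, mirroring the list comprehension in the diff above.
mask = [
    causal_row(input_pos) if i % 2 == 0 else sliding_window_row(input_pos, window)
    for i in range(num_blocks)
]
```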

ai_edge_torch/generative/examples/paligemma/image_encoder.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def __init__(self, config: cfg.ModelConfig):
         kernel_size=config.image_embedding.patch_size,
         stride=config.image_embedding.patch_size,
         padding=0,
-        use_bias=config.embedding_use_bias,
+        bias=config.embedding_use_bias,
     )
     num_patches = (
         config.image_embedding.image_size // config.image_embedding.patch_size
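A hedged sketch of this one-word fix: assuming the patch-embedding layer here is a torch.nn.Conv2d (consistent with the kernel_size/stride/padding keywords), its keyword argument is `bias`, not `use_bias`, so the old spelling would raise a TypeError at construction time. Image and embedding sizes below are made-up example values:

```python
import torch
from torch import nn

image_size, patch_size, embed_dim = 224, 14, 1152  # example numbers only

patch_embedding = nn.Conv2d(
    in_channels=3,
    out_channels=embed_dim,
    kernel_size=patch_size,
    stride=patch_size,
    padding=0,
    bias=True,  # was `use_bias=...` before this commit
)

pixels = torch.randn(1, 3, image_size, image_size)
patches = patch_embedding(pixels)  # shape (1, embed_dim, 16, 16)
num_patches = (image_size // patch_size) ** 2
assert patches.flatten(2).shape[-1] == num_patches
```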
