Skip to content

Commit 45f18d2

Browse files
ai-edge-bot authored and copybara-github committed
Fix broken PaliGemma2
- PaliGemma2 doesn't use a diagonal mask any more
- Set the image embedding scaling factor correctly

PiperOrigin-RevId: 723233385
1 parent 21d2732 commit 45f18d2

File tree

5 files changed

+13
-43
lines changed

5 files changed

+13
-43
lines changed

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
"""Example of converting a PaliGemma model to multi-signature tflite model.
17-
18-
DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
19-
https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
20-
"""
16+
"""Example of converting a PaliGemma model to multi-signature tflite model."""
2117

2218
import os
2319
import pathlib

ai_edge_torch/generative/examples/paligemma/decoder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def forward(
5555
input_embeds: torch.Tensor = None,
5656
mask: Optional[torch.Tensor] = None,
5757
export_config: Optional[model_builder.ExportConfig] = None,
58-
called_by_generate: bool = True,
5958
) -> dict[torch.Tensor, kv_utils.KVCache]:
6059
if input_embeds is None:
6160
return super().forward(
@@ -64,11 +63,11 @@ def forward(
6463

6564
assert input_embeds is not None
6665

67-
repo_pos = input_pos + 1 # PaliGemma position is 1-based.
66+
rope_pos = input_pos + 1 # PaliGemma position is 1-based.
6867
# ROPE parameters for all attn_configs are the same. Take the first one.
6968
attn_config = self.config.block_config(0).attn_config
7069
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
71-
rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)
70+
rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
7271

7372
# The first part of input_embeds are image embeddings. Diagonal causal mask
7473
# doesn't work here.

ai_edge_torch/generative/examples/paligemma/decoder2.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,34 +58,23 @@ def forward(
5858
input_embeds: torch.Tensor = None,
5959
mask: Optional[torch.Tensor] = None,
6060
export_config: Optional[model_builder.ExportConfig] = None,
61-
called_by_generate: bool = True,
6261
) -> dict[torch.Tensor, kv_utils.KVCache]:
6362
if input_embeds is None:
6463
return super().forward(tokens, input_pos, kv_cache, mask, export_config)
6564

6665
assert input_embeds is not None
6766

68-
repo_pos = input_pos + 1 # PaliGemma2 position is 1-based.
67+
rope_pos = input_pos + 1 # PaliGemma2 position is 1-based.
6968
# ROPE parameters for all attn_configs are the same. Take the first one.
7069
attn_config = self.config.block_config(0).attn_config
7170
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
72-
rope = self.config.build_rope(repo_pos, n_elem, attn_config.rotary_base)
71+
rope = self.config.build_rope(rope_pos, n_elem, attn_config.rotary_base)
7372

7473
if mask is None:
75-
if called_by_generate:
76-
# PaliGemma2 generate() uses a diagonal causal mask even with image
77-
# embeds.
78-
mask = [
79-
self.get_attention_mask(
80-
self.config.block_config(i).attn_config.attn_type, input_pos
81-
)
82-
for i in range(self.config.num_layers)
83-
]
84-
else:
85-
# By default, don't mask image embeds with a diagonal causal mask.
86-
embeds_len = input_embeds.shape[1]
87-
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
88-
mask[:, embeds_len:] = float("-inf")
74+
# By default, don't mask image embeds with a diagonal causal mask.
75+
embeds_len = input_embeds.shape[1]
76+
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
77+
mask[:, embeds_len:] = float("-inf")
8978

9079
return self._forward_with_embeds(
9180
input_embeds, rope, mask, input_pos, kv_cache, export_config

ai_edge_torch/generative/examples/paligemma/paligemma.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
"""Example of building a full-stack of PaliGemma model."""
1717

18-
from dataclasses import dataclass
18+
import dataclasses
1919
from typing import Optional
2020

2121
from ai_edge_torch.generative.examples.paligemma import decoder
@@ -31,15 +31,14 @@
3131
PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
3232

3333

34-
@dataclass
34+
@dataclasses.dataclass
3535
class PaliGemmaConfig:
3636
"""PaliGemma model configurations."""
3737

3838
image_encoder_config: cfg.ModelConfig
3939
decoder_config: cfg.ModelConfig
4040

4141
image_token_id: int
42-
image_projection_scale: float
4342
image_projection_use_bias: bool = False
4443

4544

@@ -73,7 +72,6 @@ def forward(
7372
mask: Optional[torch.Tensor] = None,
7473
pixel_values: torch.Tensor = None,
7574
export_config: Optional[model_builder.ExportConfig] = None,
76-
called_by_generate: bool = True,
7775
) -> dict[torch.Tensor, kv_utils.KVCache]:
7876
if pixel_values is None:
7977
return self.decoder(
@@ -83,14 +81,13 @@ def forward(
8381
mask=mask,
8482
input_embeds=None,
8583
export_config=export_config,
86-
called_by_generate=called_by_generate,
8784
)
8885

8986
input_embeds = self.decoder.tok_embedding(tokens)
9087

9188
image_encoded = self.image_encoder(pixel_values=pixel_values)
9289
image_embeds = self.image_projection(image_encoded)
93-
image_embeds = image_embeds / self.config.image_projection_scale
90+
image_embeds = image_embeds / self.config.decoder_config.embedding_scale
9491

9592
# Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
9693
# can be done like:
@@ -116,7 +113,6 @@ def forward(
116113
mask=mask,
117114
input_embeds=input_embeds,
118115
export_config=export_config,
119-
called_by_generate=called_by_generate,
120116
)
121117

122118

@@ -130,7 +126,6 @@ def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
130126
image_encoder_config=image_encoder.get_image_encoder_config(),
131127
decoder_config=get_decoder_config(**kwargs),
132128
image_token_id=257152,
133-
image_projection_scale=2048**0.5,
134129
image_projection_use_bias=True,
135130
)
136131

@@ -140,7 +135,6 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
140135
image_encoder_config=image_encoder.get_fake_image_encoder_config(),
141136
decoder_config=get_decoder_config(**kwargs),
142137
image_token_id=127,
143-
image_projection_scale=128**0.5,
144138
image_projection_use_bias=True,
145139
)
146140

ai_edge_torch/generative/examples/paligemma/verify.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
_PROMPTS = flags.DEFINE_string(
4343
"prompts",
44-
"describe en",
44+
"<image><bos>describe en",
4545
"The input prompts to generate answers.",
4646
)
4747
_MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -59,16 +59,9 @@
5959
class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
6060
"""Reauthored PaliGemma model wrapper."""
6161

62-
def __init__(self, model: torch.nn.Module):
63-
super().__init__(model)
64-
self.forward_called_by_generate = False
65-
6662
def _init_kv_cache(self):
6763
return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
6864

69-
def _get_extra_args_for_forward(self):
70-
return {"called_by_generate": self.forward_called_by_generate}
71-
7265

7366
def main(_):
7467
if _VERSION.value == "1":
@@ -137,7 +130,6 @@ def main(_):
137130
logging.info("outputs_from_original_model: [[%s]]", response_original)
138131

139132
logging.info("Generating answer with the reauthored model...")
140-
wrapped_reauthored_model.forward_called_by_generate = True
141133
outputs_reauthored = wrapped_reauthored_model.generate(
142134
prompts=inputs["input_ids"],
143135
pixel_values=inputs["pixel_values"],

0 commit comments

Comments (0)