
Commit 562c93d

ai-edge-bot authored and copybara-github committed
Convert PaliGemma2
- Set version 2 as the default in conversion and verification scripts
- Reduce parameters of the fake model configs. PaliGemma unit tests still crash less than 50% of the time, improved from crashing 70% of the time before; reducing some parameters seems to help.
- Updated README.md accordingly

PiperOrigin-RevId: 708838151
1 parent fa6b74d commit 562c93d

File tree

8 files changed: +55 −22 lines


ai_edge_torch/generative/examples/README.md

Lines changed: 8 additions & 6 deletions
@@ -7,12 +7,14 @@ Gemma is Google's open-source LLM. The model has both a 2B and 7B versions. See
 ## PaliGemma
 PaliGemma is a multimodal LLM which gets images and text as input, then
 generates text as output. See
-[model's Kaggle page](https://www.kaggle.com/models/google/paligemma).
-The example we provide is PaliGemma 3B with 224 image size. Since Kaggle has
-only Jax-version of PaliGemma, PyTorch model can be download from
-[here](https://huggingface.co/google/paligemma-3b-mix-224/tree/main).
-
-Note that PaliGemma can be converted to TfLite only with [ODML Torch conversion
+[model's Kaggle page](https://www.kaggle.com/models/google/paligemma2).
+The examples we provide are PaliGemma2 and PaliGemma1, both 3B with 224 image size.
+The checkpoint for PaliGemma2 can be downloaded from
+[here](https://www.kaggle.com/models/google/paligemma-2/transformers/paligemma2-3b-pt-224).
+Since Kaggle has only the JAX version of PaliGemma1, the PyTorch model of PaliGemma1
+can be downloaded from [here](https://huggingface.co/google/paligemma-3b-mix-224/tree/main).
+
+Note that PaliGemma models can be converted to TfLite only with the [ODML Torch conversion
 backend](https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental)
 
 ## Llama
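For readers following the README, a minimal sketch of loading a downloaded checkpoint with this repo's builder; the `version` and `kv_cache_max_len` arguments mirror the convert_to_tflite.py diff below, while the local path and the kv_cache_max_len value are placeholder assumptions, not values fixed by this commit.

import os
import pathlib

from ai_edge_torch.generative.examples.paligemma import paligemma

# Placeholder path; point this at the checkpoint downloaded from Kaggle.
checkpoint = os.path.join(
    pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'
)
# version=2 selects PaliGemma2; kv_cache_max_len=1024 is an assumed value.
model = paligemma.build_model(checkpoint, version=2, kv_cache_max_len=1024)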

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py

Lines changed: 11 additions & 3 deletions
@@ -29,9 +29,15 @@
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 import torch
 
+_VERSION = flags.DEFINE_enum(
+    'version',
+    '2',
+    ['1', '2'],
+    'The version of PaliGemma model to convert.',
+)
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
 _TFLITE_PATH = flags.DEFINE_string(
@@ -63,10 +69,12 @@
 
 def main(_):
   pytorch_model = paligemma.build_model(
-      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+      _CHECKPOINT_PATH.value,
+      version=int(_VERSION.value),
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
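As a quick illustration of how the new flag threads into the artifact name, here is a standalone sketch of the filename logic from the diff above; the values are hard-coded stand-ins for the absl flags, not real defaults from the script.

# Stand-in values for _VERSION, _QUANTIZE, _PREFILL_SEQ_LEN, _KV_CACHE_MAX_LEN.
version = '2'
quantize = True
prefill_seq_len = 1024
kv_cache_max_len = 1280

quant_suffix = 'q8' if quantize else 'f32'
output_filename = (
    f'paligemma{version}_{quant_suffix}'
    f'_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
)
print(output_filename)  # paligemma2_q8_seq1024_ekv1280.tflite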

ai_edge_torch/generative/examples/paligemma/decoder.py

Lines changed: 2 additions & 0 deletions
@@ -137,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
   return config
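To see why pinning the width helps the flaky tests, a back-of-the-envelope count of the embedding table alone; the 2048 figure for the unpinned width is an assumption about the real decoder config, not a value from this diff.

vocab_size = 128
fake_embedding_dim = 128           # pinned by the diff above
assumed_real_embedding_dim = 2048  # assumption about the default config

print(vocab_size * fake_embedding_dim)          # 16384 embedding weights
print(vocab_size * assumed_real_embedding_dim)  # 262144 without the override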

ai_edge_torch/generative/examples/paligemma/decoder2.py

Lines changed: 2 additions & 0 deletions
@@ -160,6 +160,8 @@ def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
   return config
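Both fake configs pair embedding_scale = 128**0.5 with embedding_dim = 128, which matches the Gemma-family convention of scaling token embeddings by the square root of the model width; the check below only verifies that reading and is not code from the repo.

import math

embedding_dim = 128
embedding_scale = embedding_dim**0.5
assert math.isclose(embedding_scale, math.sqrt(128))  # ~11.31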

ai_edge_torch/generative/examples/paligemma/paligemma.py

Lines changed: 2 additions & 2 deletions
@@ -136,8 +136,8 @@ def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
       decoder_config=get_decoder_config(**kwargs),
-      image_token_id=257152,
-      image_projection_scale=2048**0.5,
+      image_token_id=127,
+      image_projection_scale=128**0.5,
       image_projection_use_bias=True,
   )
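The token-id change keeps the fake config self-consistent: the fake decoder vocabulary has 128 entries, so the image sentinel must be a valid index into it. A quick sanity check under that reading, using only values visible in the diffs above:

fake_vocab_size = 128    # from get_fake_decoder_config above
image_token_id = 127     # new value; 257152 would index past the fake table
image_projection_scale = 128**0.5  # rescaled to match the fake width

assert image_token_id < fake_vocab_size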

ai_edge_torch/generative/examples/paligemma/verify.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 
 _VERSION = flags.DEFINE_enum(
     "version",
-    "1",
+    "2",
     ["1", "2"],
     "The version of PaliGemma model to verify.",
 )
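A minimal, self-contained sketch of the flag behavior after this change, assuming absl-py: with no --version argument the script now verifies PaliGemma2, and --version=1 restores the previous behavior.

from absl import app, flags

_VERSION = flags.DEFINE_enum(
    'version', '2', ['1', '2'], 'The version of PaliGemma model to verify.'
)

def main(_):
  # Prints '2' unless the caller passes --version=1.
  print(f'Verifying PaliGemma version {_VERSION.value}')

if __name__ == '__main__':
  app.run(main)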

ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 
 _VERSION = flags.DEFINE_enum(
     "version",
-    "1",
+    "2",
     ["1", "2"],
     "The version of PaliGemma vision model to verify.",
 )

ai_edge_torch/generative/test/test_model_conversion_large.py

Lines changed: 28 additions & 9 deletions
@@ -21,6 +21,8 @@
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
@@ -171,13 +173,9 @@ def test_amd_llama_135m(self):
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
-  def disabled_test_paligemma(self):
-    config = paligemma.get_fake_model_config()
-    pytorch_model = paligemma.PaliGemma(config).eval()
+  def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
+    config = paligemma.get_fake_model_config(decoder_config)
+    pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
 
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
@@ -215,11 +213,32 @@ def disabled_test_paligemma(self):
             kv,
             pixel_values=pixel_values,
             signature_name="prefill_pixel",
-            atol=1e-3,
-            rtol=1e-5,
+            atol=atol,
+            rtol=rtol,
         )
     )
 
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma1(self):
+    self._test_paligemma_model(
+        decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
+    )
+
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma2(self):
+    self._test_paligemma_model(
+        decoder2.Decoder2,
+        decoder2.get_fake_decoder2_config,
+        atol=1e-3,
+        rtol=1e-5,
+    )
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
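The atol/rtol pair threaded through _test_paligemma_model presumably follows the usual elementwise tolerance rule |out − ref| ≤ atol + rtol·|ref|, as in torch.allclose; the snippet below only illustrates that rule and does not reproduce the repo's comparison helper.

import torch

ref = torch.tensor([1.0, 100.0])
out = ref + 9e-4  # within atol=1e-3 even where rtol=1e-5 contributes little

print(torch.allclose(out, ref, atol=1e-3, rtol=1e-5))  # True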
