Skip to content

Commit f242ee3

Browse files
ai-edge-bot and copybara-github
authored and committed
Set enable_hlfb to true for PaliGemma image encoder
- XNNPACK supports all-zeros mask not passed by cl/722748270.
- Calculate pixel size from config, not from a command flag.
- Don't assume the first dimension of pixel values is always 1, which is not the case for Qwen VL, for example.

PiperOrigin-RevId: 725867638
1 parent 10c480c commit f242ee3

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -56,11 +56,6 @@
5656
1280,
5757
'The maximum size of KV cache buffer, including both prefill and decode.',
5858
)
59-
_PIXEL_VALUES_SIZE = flags.DEFINE_multi_integer(
60-
'pixel_values_size',
61-
[3, 224, 224],
62-
'The size of prefill pixel values except the batch dimension.',
63-
)
6459
_QUANTIZE = flags.DEFINE_bool(
6560
'quantize',
6661
True,
@@ -75,12 +70,15 @@ def main(_):
7570
kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
7671
)
7772

73+
config = pytorch_model.image_encoder.config.image_embedding
7874
converter.convert_to_tflite(
7975
pytorch_model,
8076
output_path=_OUTPUT_PATH.value,
8177
output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
8278
prefill_seq_len=_PREFILL_SEQ_LEN.value,
83-
pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
79+
pixel_values_size=torch.Size(
80+
[1, config.channels, config.image_size, config.image_size]
81+
),
8482
quantize=_QUANTIZE.value,
8583
config=pytorch_model.config.decoder_config,
8684
export_config=ExportConfig(),

ai_edge_torch/generative/examples/paligemma/image_encoder.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -136,9 +136,7 @@ def get_image_encoder_config() -> cfg.ModelConfig:
136136
image_embedding=image_embedding_config,
137137
block_configs=block_config,
138138
final_norm_config=norm_config,
139-
# TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
140-
# enable_hlfb can be set to True. See b/383865404#comment3 for details.
141-
# enable_hlfb=True,
139+
enable_hlfb=True,
142140
)
143141
return config
144142

ai_edge_torch/generative/utilities/converter.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -145,7 +145,7 @@ def _export_helper(
145145
prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
146146

147147
prefill_pixel_values = (
148-
torch.full((1,) + pixel_values_size, 0, dtype=torch.float32)
148+
torch.full(pixel_values_size, 0, dtype=torch.float32)
149149
if pixel_values_size
150150
else None
151151
)

0 commit comments

Comments
 (0)