
Commit 46e4c41

ai-edge-bot authored and copybara-github committed
Convert full-stack of Qwen2.5 VL model
- Pre-calculate and cache some values in the image encoder which can be derived from the image encoder's configuration but are hard to export with lowered ops:
  1) window_index: "array index must be concrete" at
     index_new = index_padded[index_padded != -100]
  2) reverse_index: "'vhlo.sort_v1' op is not part of the vhlo support yet" at
     reverse_index = torch.argsort(window_index)

PiperOrigin-RevId: 727000207
1 parent f242ee3 commit 46e4c41
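
In other words, both problematic ops run once in eager PyTorch when the image size is set, and the exported forward() only gathers with the cached index tensors. A minimal standalone sketch of that pattern, using toy values rather than the model's real window layout:

import torch

# Toy stand-ins: a padded window order with -100 filler, and a token tensor.
index_padded = torch.tensor([2, -100, 0, 1])
x = torch.randn(3, 8)

# Pre-calculated once on the host, outside the exported graph:
window_index = index_padded[index_padded != -100]  # data-dependent shape
reverse_index = torch.argsort(window_index)        # would need 'vhlo.sort_v1'

# The exported forward() then only gathers with the cached indices.
y = x[window_index, ...]       # rearrange into window order
out = y[reverse_index, ...]    # restore the original order
assert torch.equal(out, x)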

File tree

9 files changed: +260 -85 lines changed


ai_edge_torch/generative/examples/README.md

Lines changed: 5 additions & 1 deletion
@@ -50,7 +50,11 @@ same architecture as SmolLM but it has been trained on improved training data.
 
 ## Qwen
 Alibaba's [Qwen 2.5](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e)
-0.5B, 1B, 3B modes are also provided as examples.
+0.5B, 1B, 3B models are also provided as examples.
+
+## Qwen VL
+Alibaba's [Qwen 2.5 VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
+3B Instruct model is also provided as a multimodal model example.
 
 ## DeepSeek
 [DeepSeek-R1 distilled](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting a Qwen 2.5 VL model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.qwen_vl import qwen_vl
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen-vl'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
+    '/tmp/',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen_vl',
+    'The prefix of the output tflite model name.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_IMAGE_HEIGHT = flags.DEFINE_integer(
+    'image_height',
+    34 * 14,
+    'The height of image.',
+)
+_IMAGE_WIDTH = flags.DEFINE_integer(
+    'image_width',
+    46 * 14,
+    'The width of image.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = qwen_vl.build_model(
+      _CHECKPOINT_PATH.value,
+      kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
+      image_size=(_IMAGE_HEIGHT.value, _IMAGE_WIDTH.value),
+  )
+
+  grid_thw = pytorch_model.image_encoder.get_grid_thw()
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      pixel_values_size=(
+          pytorch_model.image_encoder.get_pixel_values_size(grid_thw)
+      ),
+      quantize=_QUANTIZE.value,
+      config=pytorch_model.config.decoder_config,
+      export_config=ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)

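With the default flags, this script builds the model for a 476x644 input image (34 * 14 by 46 * 14), a 1024-token prefill and a 1280-entry KV cache, and writes a quantized multi-signature TFLite model under /tmp/ with the qwen_vl name prefix; each of these can be overridden on the command line (e.g. --image_height, --image_width, --quantize=false).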
ai_edge_torch/generative/examples/qwen_vl/image_encoder.py

Lines changed: 75 additions & 51 deletions
@@ -16,7 +16,7 @@
 """Example of building an image encoder of Qwen 2.5 VL model."""
 
 import dataclasses
-from typing import Optional
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import attention_utils
@@ -93,7 +93,7 @@ def __init__(self, config: QwenVLImageConfig):
 
     # Tensor shape used to reshape pixel_values in forward() and various places.
     self.kernel_size = (
-        -1,  # batch size
+        -1,  # pixel_values.size(0)
         config.image_embedding.channels,
         config.image_embedding.temporal_patch_size,
         config.image_embedding.patch_size,
@@ -118,28 +118,22 @@ def __init__(self, config: QwenVLImageConfig):
     )
     self.merger = QwenVLMerger(config)
     self.config = config
+    self.set_image_size(config.image_embedding.image_size)
 
   @torch.inference_mode
-  def forward(
-      self, pixel_values: torch.Tensor, grid_thw: torch.Tensor
-  ) -> torch.Tensor:
-    # Get window index and sequence lengths to rearrange the input tensor.
-    window_index, cu_seqlens = self._get_window_index(grid_thw)
+  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+    # Check if the pixel value size matches with grid size and image config.
+    assert pixel_values.size() == self.get_pixel_values_size(self.grid_thw)
 
     # Embed the image and rearrange the embedding tensor.
-    pixel_reshaped = pixel_values.view(self.kernel_size)
+    pixel_reshaped = pixel_values.reshape(self.kernel_size)
     x = self.tok_embedding(pixel_reshaped)
     x = x.view(-1, self.config.embedding_dim)
-    x = self._rearrange(x, window_index).unsqueeze(0)
+    x = self._rearrange(x, self.window_index).unsqueeze(0)
 
-    # Get RoPE and attention mask arranged according to the window index.
-    cos, sin = self._get_rope(grid_thw)
-    rope = (
-        self._rearrange(cos, window_index),
-        self._rearrange(sin, window_index),
-    )
+    rope = self._get_rope(self.grid_thw, self.window_index)
 
-    mask = self._get_mask(x.shape[1], cu_seqlens)
+    mask = self._get_mask(self.grid_thw, self.cu_seqlens)
     full_mask = torch.zeros(x.shape[:2])
     for i, block in enumerate(self.transformer_blocks):
       x = block(
@@ -150,10 +144,42 @@ def forward(
 
     y = self.merger.forward(self.final_norm(x))
     # Arrange the output back to the original order.
-    reverse_index = torch.argsort(window_index)
-    return y[reverse_index, ...]
-
-  def _get_rope(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    return y[self.reverse_index, ...]
+
+  def set_image_size(self, image_size: Tuple[int, int]):
+    """Set the image size and pre-calculate some values including mask."""
+    self.config.image_embedding.image_size = image_size
+    self.grid_thw = self.get_grid_thw()
+
+    # Precalculate the window index which can't be lowered to HLO because of
+    # inconcrete index in:
+    # index_new = index_padded[index_padded != -100]
+    self.window_index, self.cu_seqlens = self._get_window_index(self.grid_thw)
+
+    # Precalculate the reverse index of window_index until "vhlo.sort_v1" op is
+    # supported.
+    self.reverse_index = torch.argsort(self.window_index)
+
+  def get_grid_thw(self, num_images: int = 1) -> List[Tuple[int, int, int]]:
+    """Calculate the grid size of the input images based on the image config."""
+    height, width = self.config.image_embedding.image_size
+    patch_height = height // self.config.image_embedding.patch_size
+    patch_width = width // self.config.image_embedding.patch_size
+    # Support only image, i.e. temporal step size is always 1.
+    return [(1, patch_height, patch_width)] * num_images
+
+  def get_pixel_values_size(
+      self, grid_thw: List[Tuple[int, int, int]]
+  ) -> torch.Size:
+    """Calculate the size of pixel values tensor."""
+    dim_0 = sum(t * h * w for t, h, w in grid_thw)
+    config = self.config.image_embedding
+    dim_1 = config.channels * config.temporal_patch_size * config.patch_size**2
+    return torch.Size((dim_0, dim_1))
+
+  def _get_rope(
+      self, grid_thw: List[Tuple[int, int, int]], window_index: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Get RoPE for Qwen VL model based on image grid information.
 
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb() and
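
For the converter's default 476x644 image (34 * 14 by 46 * 14 pixels), these helpers reduce to simple arithmetic; a quick standalone check with the values from the config above:

patch_size = 14          # config.image_embedding.patch_size
temporal_patch_size = 2  # config.image_embedding.temporal_patch_size
channels = 3             # config.image_embedding.channels

# get_grid_thw(): one image, temporal step size 1.
grid_thw = [(1, 476 // patch_size, 644 // patch_size)]    # [(1, 34, 46)]

# get_pixel_values_size():
dim_0 = sum(t * h * w for t, h, w in grid_thw)            # 1 * 34 * 46 = 1564
dim_1 = channels * temporal_patch_size * patch_size**2    # 3 * 2 * 196 = 1176
# The encoder therefore expects pixel_values of size (1564, 1176).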
@@ -182,16 +208,20 @@ def _get_rope(self, grid_thw: torch.Tensor) -> torch.Tensor:
       wpos_ids = wpos_ids.flatten()
       pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
     pos_ids = torch.cat(pos_ids, dim=0)
-    max_grid_size = grid_thw[:, 1:].max()
+    # Assume all the heights and widths are the same for all images.
+    max_grid_size = max(grid_thw[0][1], grid_thw[0][2])
 
     cos, sin = attention_utils.build_rope_cache(
         max_grid_size,
         # ROPE parameters for all attn_configs are the same. Take the first one.
         self.config.block_config(0).attn_config.head_dim // 2,
     )
-    return cos[pos_ids].flatten(1), sin[pos_ids].flatten(1)
+    return (
+        self._rearrange(cos[pos_ids].flatten(1), window_index),
+        self._rearrange(sin[pos_ids].flatten(1), window_index),
+    )
 
-  def _get_window_index(self, grid_thw: torch.Tensor):
+  def _get_window_index(self, grid_thw: List[Tuple[int, int, int]]):
     """Get window index for Qwen VL model to rearrange the input tensor.
 
     It's copied from Qwen2_5_VisionTransformerPretrainedModel.get_window_index()
@@ -207,13 +237,10 @@ def _get_window_index(self, grid_thw: torch.Tensor):
     )
 
     for grid_t, grid_h, grid_w in grid_thw:
-      llm_grid_h, llm_grid_w = (
-          grid_h // self.config.spatial_merge_size,
-          grid_w // self.config.spatial_merge_size,
-      )
-      index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
-          grid_t, llm_grid_h, llm_grid_w
-      )
+      llm_grid_h = grid_h // self.config.spatial_merge_size
+      llm_grid_w = grid_w // self.config.spatial_merge_size
+      index = torch.arange(grid_t * llm_grid_h * llm_grid_w)
+      index = index.reshape((grid_t, llm_grid_h, llm_grid_w))
       pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
       pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
       num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
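
Plugging in the default grid of (1, 34, 46) with spatial_merge_size=2 gives a feel for this padding arithmetic; vit_merger_window_size is assumed here to be window_size // spatial_merge_size // patch_size = 112 // 2 // 14 = 4, as in the upstream Qwen 2.5 VL code (its definition sits outside this hunk):

spatial_merge_size = 2
vit_merger_window_size = 4   # assumed: 112 // 2 // 14
grid_t, grid_h, grid_w = 1, 34, 46

llm_grid_h = grid_h // spatial_merge_size   # 17
llm_grid_w = grid_w // spatial_merge_size   # 23
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size  # 4 - 1 = 3
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size  # 4 - 3 = 1
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size        # 5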
@@ -236,18 +263,14 @@ def _get_window_index(self, grid_thw: torch.Tensor):
       index_padded = index_padded.reshape(-1)
       index_new = index_padded[index_padded != -100]
       window_index.append(index_new + window_index_id)
-      spatial_merge_unit = (
-          self.config.spatial_merge_size * self.config.spatial_merge_size
-      )
+      spatial_merge_unit = self.config.spatial_merge_size**2
       cu_seqlens_tmp = (
          seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
       )
       cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
-      window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+      window_index_id += grid_t * llm_grid_h * llm_grid_w
 
     window_index = torch.cat(window_index, dim=0)
-    cu_window_seqlens = torch.tensor(cu_window_seqlens)
-    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
     return window_index, cu_window_seqlens
 
   def _rearrange(
def _rearrange(
@@ -258,20 +281,20 @@ def _rearrange(
258281
It's copied from Qwen2_5_VisionTransformerPretrainedModel.forward() and
259282
modified accordingly.
260283
"""
261-
size = x.shape[0]
262-
spatial_merge_unit = (
263-
self.config.spatial_merge_size * self.config.spatial_merge_size
264-
)
265-
x_reshaped = x.view(size // spatial_merge_unit, spatial_merge_unit, -1)
284+
spatial_merge_unit = self.config.spatial_merge_size**2
285+
x_reshaped = x.view(x.size(0) // spatial_merge_unit, spatial_merge_unit, -1)
266286
x_rearranged = x_reshaped[window_index, ...]
267-
return x_rearranged.view(size, -1)
287+
return x_rearranged.view(x.shape)
268288

269-
def _get_mask(self, seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
289+
def _get_mask(
290+
self, grid_thw: List[Tuple[int, int, int]], cu_seqlens: List[int]
291+
) -> torch.Tensor:
270292
"""Get attention mask for Qwen VL model.
271293
272294
It's copied from Qwen2_5_VLVisionAttention.forward() and modified
273295
accordingly.
274296
"""
297+
seqlen = self.get_pixel_values_size(grid_thw)[0]
275298
mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
276299
for i in range(1, len(cu_seqlens)):
277300
mask[
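
_get_mask() turns the cumulative window lengths into a block-diagonal attention mask, so tokens only attend within their own window. A toy version of that construction (the cu_seqlens values here are made up, and the slice assignment mirrors the upstream code since the loop body is truncated in this hunk):

import torch

cu_seqlens = [0, 4, 6]   # toy cumulative lengths: two windows of 4 and 2 tokens
seqlen = cu_seqlens[-1]
mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
for i in range(1, len(cu_seqlens)):
  mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0
# Tokens 0-3 attend only to 0-3, tokens 4-5 only to 4-5.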
@@ -282,15 +305,15 @@ def _get_mask(self, seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
     return mask
 
 
-def get_image_encoder_config() -> QwenVLImageConfig:
+def get_image_encoder_config(image_size: Tuple[int, int]) -> QwenVLImageConfig:
   """Returns the model config for the image encoder of a Qwen 2.5 VL model.
 
   Returns:
     The model config for the image encoder of a Qwen 2.5 VL model.
   """
   image_embedding_config = cfg.ImageEmbeddingConfig(
       channels=3,
-      image_size=0,  # Not used in image encoder.
+      image_size=image_size,
       patch_size=14,
       temporal_patch_size=2,
   )
@@ -336,15 +359,13 @@ def get_image_encoder_config() -> QwenVLImageConfig:
       window_size=112,
       spatial_merge_size=2,
       full_atten_block_indexes=[7, 15, 23, 31],
-      # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
-      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
-      # enable_hlfb=True,
+      enable_hlfb=True,
   )
   return config
 
 
 def get_fake_image_encoder_config() -> QwenVLImageConfig:
-  config = get_image_encoder_config()
+  config = get_image_encoder_config((8, 12))
   # PaliGemma image encoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.image_embedding.patch_size = 2
@@ -353,8 +374,11 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
   return config
 
 
-def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
-  config = get_image_encoder_config()
+def build_image_encoder(
+    checkpoint_path: str,
+    image_size: Tuple[int, int] = (34 * 14, 46 * 14),
+) -> QwenVLImageEncoder:
+  config = get_image_encoder_config(image_size)
   encoder = QwenVLImageEncoder(config)
   load_image_encoder(checkpoint_path, encoder)
   encoder.eval()