
Commit b43a14e

ai-edge-bot authored and copybara-github committed
Add image encoder and merger of Qwen2.5-VL model.
- Image encoder rearranges input tensors.
- Full-stack of Qwen2.5-VL model and its conversion will be in a following CL.

PiperOrigin-RevId: 722673107
1 parent 78bbac2 commit b43a14e
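
The rearrangement groups patch embeddings by attention window before the transformer blocks run. With the default values added in this CL (patch_size=14, spatial_merge_size=2, window_size=112), the sizes work out as in this illustrative sketch (not part of the change itself):

# Window/merge arithmetic implied by get_image_encoder_config() below.
patch_size = 14          # pixels per patch side
spatial_merge_size = 2   # a 2x2 group of patches becomes one merged token
window_size = 112        # attention window side, in pixels

# Matches vit_merger_window_size in _get_window_index().
tokens_per_window_side = window_size // spatial_merge_size // patch_size
assert tokens_per_window_side == 4   # each window covers 4x4 merged tokens

# The merger flattens each 2x2 patch group before its MLP, so its input width
# is embedding_dim * spatial_merge_size**2 = 1280 * 4 = 5120.
assert 1280 * spatial_merge_size**2 == 5120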

File tree

5 files changed: +466 −1 lines changed

ai_edge_torch/generative/examples/paligemma/image_encoder.py

Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ def __init__(self, config: cfg.ModelConfig):
         kernel_size=config.image_embedding.patch_size,
         stride=config.image_embedding.patch_size,
         padding=0,
+        use_bias=config.embedding_use_bias,
     )
     num_patches = (
         config.image_embedding.image_size // config.image_embedding.patch_size
Lines changed: 379 additions & 0 deletions

@@ -0,0 +1,379 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of building an image encoder of Qwen 2.5 VL model."""

import dataclasses
from typing import Optional

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import attention_utils
from ai_edge_torch.generative.layers import builder
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
from torch import nn
import torch.nn.functional as F

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="visual.blocks.{}.mlp.up_proj",
    ff_down_proj="visual.blocks.{}.mlp.down_proj",
    ff_gate_proj="visual.blocks.{}.mlp.gate_proj",
    attn_fused_qkv_proj="visual.blocks.{}.attn.qkv",
    attn_output_proj="visual.blocks.{}.attn.proj",
    pre_attn_norm="visual.blocks.{}.norm1",
    post_attn_norm="visual.blocks.{}.norm2",
    embedding="visual.patch_embed.proj",
    final_norm="visual.merger.ln_q",
)

MERGER_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="visual.merger.mlp.0",
    ff_down_proj="visual.merger.mlp.2",
)

@dataclasses.dataclass
class QwenVLMergerConfig:
  """Merger parameters."""

  activation: cfg.ActivationConfig
  intermediate_size: int
  out_embedding_dim: int
  use_bias: bool = False


@dataclasses.dataclass
class QwenVLImageConfig(cfg.ModelConfig):
  """Model config for Qwen 2.5 VL model."""

  merger_config: Optional[QwenVLMergerConfig] = None
  window_size: Optional[int] = None
  spatial_merge_size: Optional[int] = None
  full_atten_block_indexes: Optional[list[int]] = None


class QwenVLMerger(nn.Module):
  """Merger of Qwen 2.5 VL models from the Edge Generative API.

  It's based on Qwen2_5_VLPatchMerger.
  """

  def __init__(self, config: QwenVLImageConfig):
    super().__init__()
    self.intermediate_size = config.merger_config.intermediate_size
    self.w1 = nn.Linear(self.intermediate_size, self.intermediate_size)
    self.act = builder.get_activation(config.merger_config.activation)
    self.w2 = nn.Linear(
        self.intermediate_size, config.merger_config.out_embedding_dim
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x_reshaped = x.view(-1, self.intermediate_size)
    return self.w2(self.act(self.w1(x_reshaped)))

class QwenVLImageEncoder(nn.Module):
  """Image encoder of Qwen 2.5 VL models from the Edge Generative API."""

  def __init__(self, config: QwenVLImageConfig):
    super().__init__()

    # Tensor shape used to reshape pixel_values in forward() and various places.
    self.kernel_size = (
        -1,  # batch size
        config.image_embedding.channels,
        config.image_embedding.temporal_patch_size,
        config.image_embedding.patch_size,
        config.image_embedding.patch_size,
    )
    self.tok_embedding = nn.Conv3d(
        in_channels=self.kernel_size[1],
        out_channels=config.embedding_dim,
        kernel_size=self.kernel_size[2:],
        stride=self.kernel_size[2:],
        padding=0,
        bias=config.embedding_use_bias,
    )

    self.transformer_blocks = nn.ModuleList(
        attention.TransformerBlock(config.block_config(idx), config)
        for idx in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.merger = QwenVLMerger(config)
    self.config = config

  @torch.inference_mode
  def forward(
      self, pixel_values: torch.Tensor, grid_thw: torch.Tensor
  ) -> torch.Tensor:
    # Get window index and sequence lengths to rearrange the input tensor.
    window_index, cu_seqlens = self._get_window_index(grid_thw)

    # Embed the image and rearrange the embedding tensor.
    pixel_reshaped = pixel_values.view(self.kernel_size)
    x = self.tok_embedding(pixel_reshaped)
    x = x.view(-1, self.config.embedding_dim)
    x = self._rearrange(x, window_index).unsqueeze(0)

    # Get RoPE and attention mask arranged according to the window index.
    cos, sin = self._get_rope(grid_thw)
    rope = (
        self._rearrange(cos, window_index),
        self._rearrange(sin, window_index),
    )

    mask = self._get_mask(x.shape[1], cu_seqlens)
    full_mask = torch.zeros(x.shape[:2])
    for i, block in enumerate(self.transformer_blocks):
      x = block(
          x,
          rope=rope,
          mask=full_mask if i in self.config.full_atten_block_indexes else mask,
      )

    y = self.merger.forward(self.final_norm(x))
    # Arrange the output back to the original order.
    reverse_index = torch.argsort(window_index)
    return y[reverse_index, ...]
  def _get_rope(self, grid_thw: torch.Tensor) -> torch.Tensor:
    """Get RoPE for Qwen VL model based on image grid information.

    It's copied from Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb() and
    modified accordingly.
    """
    pos_ids = []
    for t, h, w in grid_thw:
      hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
      hpos_ids = hpos_ids.reshape(
          h // self.config.spatial_merge_size,
          self.config.spatial_merge_size,
          w // self.config.spatial_merge_size,
          self.config.spatial_merge_size,
      )
      hpos_ids = hpos_ids.permute(0, 2, 1, 3)
      hpos_ids = hpos_ids.flatten()

      wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
      wpos_ids = wpos_ids.reshape(
          h // self.config.spatial_merge_size,
          self.config.spatial_merge_size,
          w // self.config.spatial_merge_size,
          self.config.spatial_merge_size,
      )
      wpos_ids = wpos_ids.permute(0, 2, 1, 3)
      wpos_ids = wpos_ids.flatten()
      pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
    pos_ids = torch.cat(pos_ids, dim=0)
    max_grid_size = grid_thw[:, 1:].max()

    cos, sin = attention_utils.build_rope_cache(
        max_grid_size,
        # ROPE parameters for all attn_configs are the same. Take the first one.
        self.config.block_config(0).attn_config.head_dim // 2,
    )
    return cos[pos_ids].flatten(1), sin[pos_ids].flatten(1)

  def _get_window_index(self, grid_thw: torch.Tensor):
    """Get window index for Qwen VL model to rearrange the input tensor.

    It's copied from Qwen2_5_VisionTransformerPretrainedModel.get_window_index()
    and modified accordingly.
    """
    window_index: list = []
    cu_window_seqlens: list = [0]
    window_index_id = 0
    vit_merger_window_size = (
        self.config.window_size
        // self.config.spatial_merge_size
        // self.config.image_embedding.patch_size
    )

    for grid_t, grid_h, grid_w in grid_thw:
      llm_grid_h, llm_grid_w = (
          grid_h // self.config.spatial_merge_size,
          grid_w // self.config.spatial_merge_size,
      )
      index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
          grid_t, llm_grid_h, llm_grid_w
      )
      pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
      pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
      num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
      num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
      index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
      index_padded = index_padded.reshape(
          grid_t,
          num_windows_h,
          vit_merger_window_size,
          num_windows_w,
          vit_merger_window_size,
      )
      index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
          grid_t,
          num_windows_h * num_windows_w,
          vit_merger_window_size,
          vit_merger_window_size,
      )
      seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
      index_padded = index_padded.reshape(-1)
      index_new = index_padded[index_padded != -100]
      window_index.append(index_new + window_index_id)
      spatial_merge_unit = (
          self.config.spatial_merge_size * self.config.spatial_merge_size
      )
      cu_seqlens_tmp = (
          seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
      )
      cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
      window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()

    window_index = torch.cat(window_index, dim=0)
    cu_window_seqlens = torch.tensor(cu_window_seqlens)
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
    return window_index, cu_window_seqlens

  def _rearrange(
      self, x: torch.Tensor, window_index: torch.Tensor
  ) -> torch.Tensor:
    """Rearrange the tensor according to window_index.

    It's copied from Qwen2_5_VisionTransformerPretrainedModel.forward() and
    modified accordingly.
    """
    size = x.shape[0]
    spatial_merge_unit = (
        self.config.spatial_merge_size * self.config.spatial_merge_size
    )
    x_reshaped = x.view(size // spatial_merge_unit, spatial_merge_unit, -1)
    x_rearranged = x_reshaped[window_index, ...]
    return x_rearranged.view(size, -1)

  def _get_mask(self, seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Get attention mask for Qwen VL model.

    It's copied from Qwen2_5_VLVisionAttention.forward() and modified
    accordingly.
    """
    mask = torch.full([1, 1, seqlen, seqlen], float("-inf"))
    for i in range(1, len(cu_seqlens)):
      mask[
          ...,
          cu_seqlens[i - 1] : cu_seqlens[i],
          cu_seqlens[i - 1] : cu_seqlens[i],
      ] = 0
    return mask

def get_image_encoder_config() -> QwenVLImageConfig:
  """Returns the model config for the image encoder of a Qwen 2.5 VL model.

  Returns:
    The model config for the image encoder of a Qwen 2.5 VL model.
  """
  image_embedding_config = cfg.ImageEmbeddingConfig(
      channels=3,
      image_size=0,  # Not used in image encoder.
      patch_size=14,
      temporal_patch_size=2,
  )
  attn_config = cfg.AttentionConfig(
      num_heads=16,
      head_dim=80,
      num_query_groups=16,
      qkv_transpose_before_split=True,
      qkv_use_bias=True,
      output_proj_use_bias=True,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
      intermediate_size=3420,
      use_bias=True,
  )
  norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-6,
  )
  block_config = cfg.TransformerBlockConfig(
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      post_attention_norm_config=norm_config,
  )
  merger_config = QwenVLMergerConfig(
      activation=cfg.ActivationConfig(cfg.ActivationType.GELU),
      intermediate_size=5120,  # embedding_dim(1280) * spatial_merge_size(2)^2
      out_embedding_dim=2048,  # embedding_dim of decoder config.
      use_bias=True,
  )
  config = QwenVLImageConfig(
      vocab_size=0,  # Not used in image encoder.
      num_layers=32,
      max_seq_len=0,  # Not used in image encoder.
      embedding_dim=1280,
      image_embedding=image_embedding_config,
      block_configs=block_config,
      final_norm_config=norm_config,
      merger_config=merger_config,
      window_size=112,
      spatial_merge_size=2,
      full_atten_block_indexes=[7, 15, 23, 31],
      # TODO: b/377051577 - Once RemoveSDPACompositeZeroMaskPass is removed,
      # enable_hlfb can be set to True. See b/383865404#comment3 for details.
      # enable_hlfb=True,
  )
  return config

def get_fake_image_encoder_config() -> QwenVLImageConfig:
  config = get_image_encoder_config()
  # The Qwen VL image encoder has only one block config.
  config.block_config(0).ff_config.intermediate_size = 128
  config.image_embedding.patch_size = 2
  config.num_layers = 2
  config.merger_config.intermediate_size = 128
  return config


def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
  config = get_image_encoder_config()
  encoder = QwenVLImageEncoder(config)
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
  # Loosen the strictness because only the image encoder is being loaded.
  loader.load(encoder, strict=False)

  # Load merger weights.
  merger_loader = loading_utils.ModelLoader(checkpoint_path, None)
  state = merger_loader.get_state()
  w1_state = dict()
  w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
  if config.merger_config.use_bias:
    w1_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.bias")
  encoder.merger.w1.load_state_dict(w1_state)

  w2_state = dict()
  w2_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.weight")
  if config.merger_config.use_bias:
    w2_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.bias")
  encoder.merger.w2.load_state_dict(w2_state)

  encoder.eval()
  return encoder
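
A minimal usage sketch for the new encoder (illustrative only; the qwen_vl module path, the checkpoint location, and the 224x224 single-image preprocessing are assumptions, not part of this CL):

import torch
from ai_edge_torch.generative.examples.qwen_vl import image_encoder  # assumed path

# Hypothetical local checkpoint containing the Qwen2.5-VL "visual.*" weights.
encoder = image_encoder.build_image_encoder("/tmp/qwen2.5-vl-checkpoint")

# One 224x224 frame -> a 1x16x16 grid of 14x14 patches. Each flattened patch
# carries channels(3) * temporal_patch_size(2) * patch_size(14)^2 = 1176 values.
grid_thw = torch.tensor([[1, 16, 16]])
pixel_values = torch.randn(16 * 16, 1176)

image_embeds = encoder(pixel_values, grid_thw)
# 2x2 patch groups are merged, so 256 patches -> 64 tokens of width 2048.
print(image_embeds.shape)  # torch.Size([64, 2048])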
