
Commit 48d225a

ai-edge-bot authored and copybara-github committed
Full stack of Qwen2.5-VL model.
- Handles multimodal RoPE.
- Supports only one image, placed at the beginning of the input embeddings.
- Though not fully compatible with the original model, it generates reasonable outputs.
- Conversion will be in a following CL.

PiperOrigin-RevId: 724076611
1 parent 729690b commit 48d225a

5 files changed, +412 -10 lines changed

ai_edge_torch/generative/examples/qwen_vl/decoder.py

Lines changed: 49 additions & 4 deletions
@@ -15,16 +15,61 @@

 """Example of building decoder for Qwen 2.5 VL models."""

+from typing import Optional, Tuple
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.utilities import model_builder
-from torch import nn
+import torch

 TENSOR_NAMES = model_builder.TENSOR_NAMES


 class Decoder(model_builder.DecoderOnlyModel):
-  """A decoder for Qwen-VL model built from the Edge Generative API layers."""
-  pass
+  """A decoder for Qwen-VL model built from the Edge Generative API layers.
+
+  Besides a tensor of text token IDs, forward() can also take a tensor of
+  embeddings which may include text or image or both.
+  """
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      input_embeds: torch.Tensor = None,
+      rope: Tuple[torch.Tensor, torch.Tensor] = None,
+      mask: Optional[torch.Tensor] = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if input_embeds is None:
+      _, seq_len = tokens.size()
+      assert self.config.max_seq_len >= seq_len, (
+          f"Cannot forward sequence of length {seq_len}, max seq length is only"
+          f" {self.config.max_seq_len}"
+      )
+      # token embeddings of shape (b, t, n_embd)
+      input_embeds = self.tok_embedding(tokens)
+
+    if rope is None:
+      # ROPE parameters for all attn_configs are the same. Take the first one.
+      attn_config = self.config.block_config(0).attn_config
+      n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+      rope = self.config.build_rope(input_pos, n_elem, attn_config.rotary_base)
+
+    if mask is None:
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    return self._forward_with_embeds(
+        input_embeds,
+        rope,
+        mask,
+        input_pos,
+        kv_cache,
+        export_config=export_config,
+    )


 def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -82,7 +127,7 @@ def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
   return config


-def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module:
+def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
       config=get_decoder_config(**kwargs),
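A minimal usage sketch of the two paths the new forward() supports. The checkpoint path below is a placeholder, and kv_utils.KVCache.from_model_config() is assumed to be available as in the other generative examples; this is an editor's illustration, not part of the commit.

import torch

from ai_edge_torch.generative.examples.qwen_vl import decoder
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Placeholder checkpoint path; point this at a converted Qwen2.5-VL checkpoint.
model = decoder.build_decoder("/tmp/qwen2.5-vl", kv_cache_max_len=1024)
kv = kv_utils.KVCache.from_model_config(model.config)

tokens = torch.zeros((1, 16), dtype=torch.int)    # dummy token IDs
input_pos = torch.arange(0, 16, dtype=torch.int)

# Text-only path: the decoder embeds the tokens itself and builds its own RoPE.
out = model(tokens, input_pos, kv)

# Multimodal path: the caller passes precomputed embeddings (and, optionally,
# its own RoPE and mask), in which case `tokens` is ignored.
embeds = model.tok_embedding(tokens)
out = model(tokens, input_pos, kv, input_embeds=embeds)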

ai_edge_torch/generative/examples/qwen_vl/image_encoder.py

Lines changed: 8 additions & 5 deletions
@@ -356,6 +356,12 @@ def get_fake_image_encoder_config() -> QwenVLImageConfig:
 def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
   config = get_image_encoder_config()
   encoder = QwenVLImageEncoder(config)
+  load_image_encoder(checkpoint_path, encoder)
+  encoder.eval()
+  return encoder
+
+
+def load_image_encoder(checkpoint_path: str, encoder: QwenVLImageEncoder):
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   # Loose the strictness because only image encoder is being loaded.
   loader.load(encoder, strict=False)
@@ -365,15 +371,12 @@ def build_image_encoder(checkpoint_path: str) -> QwenVLImageEncoder:
   state = merger_loader.get_state()
   w1_state = dict()
   w1_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.weight")
-  if config.merger_config.use_bias:
+  if encoder.config.merger_config.use_bias:
     w1_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_up_proj}.bias")
   encoder.merger.w1.load_state_dict(w1_state)

   w2_state = dict()
   w2_state["weight"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.weight")
-  if config.merger_config.use_bias:
+  if encoder.config.merger_config.use_bias:
     w2_state["bias"] = state.pop(f"{MERGER_TENSOR_NAMES.ff_down_proj}.bias")
   encoder.merger.w2.load_state_dict(w2_state)
-
-  encoder.eval()
-  return encoder
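A short sketch of the refactor's intent, using only the entry points shown above (the checkpoint path is a placeholder): build_image_encoder() keeps its one-shot behavior, while load_image_encoder() lets a caller that already owns a QwenVLImageEncoder, such as the full-stack model below, load just the encoder weights.

from ai_edge_torch.generative.examples.qwen_vl import image_encoder

# One-shot path, unchanged for existing callers: build, load, eval.
encoder = image_encoder.build_image_encoder("/tmp/qwen2.5-vl")

# Reuse path: construct the encoder elsewhere (e.g. inside QwenVL.__init__)
# and load only its weights from the combined checkpoint.
config = image_encoder.get_image_encoder_config()
encoder2 = image_encoder.QwenVLImageEncoder(config)
image_encoder.load_image_encoder("/tmp/qwen2.5-vl", encoder2)
encoder2.eval()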
Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a full stack of the Qwen 2.5 VL model."""
+
+import dataclasses
+from typing import List, Optional, Tuple
+
+from ai_edge_torch.generative.examples.qwen_vl import decoder
+from ai_edge_torch.generative.examples.qwen_vl import image_encoder
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+
+@dataclasses.dataclass
+class QwenVLConfig:
+  """Qwen VL model configurations."""
+
+  image_encoder_config: image_encoder.QwenVLImageConfig
+  decoder_config: cfg.ModelConfig
+  image_token_id: int
+  mrope_section: List[int]
+
+
+class QwenVL(nn.Module):
+  """Qwen VL model from the Edge Generative API."""
+
+  def __init__(self, config: QwenVLConfig):
+    super().__init__()
+
+    self.image_encoder = image_encoder.QwenVLImageEncoder(
+        config.image_encoder_config
+    )
+    self.decoder = decoder.Decoder(config.decoder_config)
+    # The amount of adjustment to input_pos to calculate RoPE properly in
+    # forward() calls after the image is handled.
+    self.rope_pos_adjust = 0
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      mask: Optional[torch.Tensor] = None,
+      pixel_values: torch.Tensor = None,
+      grid_thw: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if pixel_values is None:
+      return self.decoder(
+          tokens=tokens,
+          input_pos=input_pos,
+          kv_cache=kv_cache,
+          mask=mask,
+          rope=self._build_text_rope(input_pos),
+          input_embeds=None,
+          export_config=export_config,
+      )
+
+    input_embeds = self.decoder.tok_embedding(tokens)
+    image_embeds = self.image_encoder(pixel_values, grid_thw).unsqueeze(0)
+
+    # Merging image_embeds into input_embeds, as done in
+    # PaliGemmaForConditionalGeneration, could be written as:
+    #
+    #   image_mask = tokens == self.config.image_token_id
+    #   image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
+    #   input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    #
+    # Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
+    # Assume that the image is put at the beginning of the input sequence,
+    # wrapped with vision_start and vision_end tokens.
+    input_embeds = torch.cat(
+        (
+            input_embeds[:, :1, :],
+            image_embeds,
+            input_embeds[:, image_embeds.shape[1] + 1 :, :],
+        ),
+        dim=1,
+    )
+
+    return self.decoder(
+        tokens=None,
+        input_pos=input_pos,
+        kv_cache=kv_cache,
+        mask=mask,
+        input_embeds=input_embeds,
+        rope=self._build_multimodal_rope(input_pos, grid_thw),
+        export_config=export_config,
+    )
+
+  def _build_rope(
+      self, rope_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.decoder_config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    return self.config.decoder_config.build_rope(
+        rope_pos, n_elem, attn_config.rotary_base
+    )
+
+  def _build_text_rope(
+      self, input_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Reset rope_pos_adjust to 0 when the input sequence starts from scratch,
+    # i.e. input_pos[0] == 0.
+    if input_pos[0] == 0:
+      self.rope_pos_adjust = 0
+    return self._build_rope(input_pos + self.rope_pos_adjust)
+
+  def _build_multimodal_rope(
+      self, input_pos: torch.Tensor, grid_thw: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Builds RoPE of multimodal input for the Qwen VL model.
+
+    It's copied from Qwen2_5_VLForConditionalGeneration.get_rope_index() and
+    simplified based on the assumption that an image is put at the beginning
+    of the input sequence with vision start and vision end tokens.
+    """
+    spatial_merge_size = self.config.image_encoder_config.spatial_merge_size
+    height = grid_thw[0][1] // spatial_merge_size
+    width = grid_thw[0][2] // spatial_merge_size
+    image_pos_max = max(height, width)
+    image_pos_count = height * width
+
+    # The position of the vision end token and the text tokens after the image.
+    text_pos_start = image_pos_max + 1
+    text_pos_count = len(input_pos) - image_pos_count - 1
+    text_pos = torch.arange(text_pos_start, text_pos_start + text_pos_count)
+    # Set rope_pos_adjust since text_pos_start has changed.
+    self.rope_pos_adjust = image_pos_max - image_pos_count
+
+    temporal_rope = self._build_image_text_rope(
+        torch.ones(image_pos_count, dtype=torch.int), text_pos
+    )
+    height_rope = self._build_image_text_rope(
+        torch.arange(1, height + 1).view(-1, 1).expand(-1, width).flatten(),
+        text_pos,
+    )
+    width_rope = self._build_image_text_rope(
+        torch.arange(1, width + 1).view(1, -1).expand(height, -1).flatten(),
+        text_pos,
+    )
+
+    return (
+        self._merge_ropes(temporal_rope[0], height_rope[0], width_rope[0]),
+        self._merge_ropes(temporal_rope[1], height_rope[1], width_rope[1]),
+    )
+
+  def _build_image_text_rope(
+      self, image_pos: torch.Tensor, text_pos: torch.Tensor
+  ) -> Tuple[torch.Tensor, torch.Tensor]:
+    return self._build_rope(
+        torch.cat((torch.zeros(1, dtype=torch.int), image_pos, text_pos))
+    )
+
+  def _merge_ropes(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
+    """Merges RoPE tensors based on apply_multimodal_rotary_pos_emb()."""
+    split = torch.stack([a, b, c]).split(self.config.mrope_section, dim=-1)
+    return torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)
+
+
+def get_model_config(**kwargs) -> QwenVLConfig:
+  """Returns the model config for a Qwen 2.5 VL model.
+
+  Returns:
+    The model config for a Qwen 2.5 VL model.
+  """
+  return QwenVLConfig(
+      image_encoder_config=image_encoder.get_image_encoder_config(),
+      decoder_config=decoder.get_decoder_config(**kwargs),
+      image_token_id=151655,
+      mrope_section=[16, 24, 24],
+  )
+
+
+def get_fake_model_config(**kwargs) -> QwenVLConfig:
+  return QwenVLConfig(
+      image_encoder_config=image_encoder.get_fake_image_encoder_config(),
+      decoder_config=decoder.get_fake_decoder_config(**kwargs),
+      image_token_id=127,
+  )
+
+
+def build_model(checkpoint_path: str, **kwargs) -> QwenVL:
+  config = get_model_config(**kwargs)
+  model = QwenVL(config)
+  image_encoder.load_image_encoder(checkpoint_path, model.image_encoder)
+  # Load the parameters of decoder.
+  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+  loader.load(model.decoder, strict=False)
+  model.eval()
+  return model
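For illustration, a self-contained sketch of the _merge_ropes() interleaving above. The values and the last dimension are made up for the example (the real cos/sin layout comes from build_rope()); it only demonstrates how the three per-axis RoPE tensors are split by mrope_section and stitched back together, taking sections from the temporal, height, and width tensors in turn.

import torch

mrope_section = [16, 24, 24]      # as in get_model_config()
last_dim = sum(mrope_section)     # toy last dimension for the example
seq_len = 8

t = torch.zeros(seq_len, last_dim)          # temporal-axis cos (or sin)
h = torch.ones(seq_len, last_dim)           # height-axis cos
w = torch.full((seq_len, last_dim), 2.0)    # width-axis cos

# Same expression as QwenVL._merge_ropes().
split = torch.stack([t, h, w]).split(mrope_section, dim=-1)
merged = torch.cat([m[i % 3] for i, m in enumerate(split)], dim=-1)

# Channels 0..15 come from t, 16..39 from h, 40..63 from w.
print(merged[0, :16].unique(), merged[0, 16:40].unique(), merged[0, 40:].unique())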
