Skip to content

Commit 78bbac2

Browse files
ai-edge-botcopybara-github
authored andcommitted
Add decoder of Qwen2.5-VL model.
- Image encoder and full Qwen2.5-VL model will be added in following CLs - Decoder outputs last hidden states instead of logits PiperOrigin-RevId: 721541592
1 parent f87e4fa commit 78bbac2

File tree

3 files changed

+182
-0
lines changed

3 files changed

+182
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Example of building decoder for Qwen 2.5 VL models."""
17+
18+
import ai_edge_torch.generative.layers.model_config as cfg
19+
from ai_edge_torch.generative.utilities import model_builder
20+
from torch import nn
21+
22+
TENSOR_NAMES = model_builder.TENSOR_NAMES
23+
24+
25+
class Decoder(model_builder.DecoderOnlyModel):
26+
"""A decoder for Qwen-VL model built from the Edge Generative API layers."""
27+
pass
28+
29+
30+
def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
31+
"""Returns the model config for a Qwen 2.5 VL 3B model.
32+
33+
Args:
34+
kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
35+
is 1024.
36+
37+
Returns:
38+
The model config for a Qwen 2.5 VL 3B model.
39+
"""
40+
attn_config = cfg.AttentionConfig(
41+
num_heads=16,
42+
head_dim=128,
43+
num_query_groups=2,
44+
rotary_base=1000000,
45+
rotary_percentage=1.0,
46+
qkv_use_bias=True,
47+
)
48+
ff_config = cfg.FeedForwardConfig(
49+
type=cfg.FeedForwardType.GATED,
50+
activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
51+
intermediate_size=11008,
52+
)
53+
norm_config = cfg.NormalizationConfig(
54+
type=cfg.NormalizationType.RMS_NORM,
55+
epsilon=1e-06,
56+
)
57+
block_config = cfg.TransformerBlockConfig(
58+
attn_config=attn_config,
59+
ff_config=ff_config,
60+
pre_attention_norm_config=norm_config,
61+
post_attention_norm_config=norm_config,
62+
)
63+
config = cfg.ModelConfig(
64+
vocab_size=151936,
65+
num_layers=36,
66+
max_seq_len=32768,
67+
embedding_dim=2048,
68+
kv_cache_max_len=kv_cache_max_len,
69+
block_configs=block_config,
70+
final_norm_config=norm_config,
71+
enable_hlfb=True,
72+
)
73+
return config
74+
75+
76+
def get_fake_decoder_config(**kwargs) -> cfg.ModelConfig:
77+
config = get_decoder_config(**kwargs)
78+
config.vocab_size = 128
79+
config.num_layers = 2
80+
# Decoder has only one block config.
81+
config.block_config(0).ff_config.intermediate_size = 64
82+
return config
83+
84+
85+
def build_decoder(checkpoint_path: str, **kwargs) -> nn.Module:
86+
return model_builder.build_decoder_only_model(
87+
checkpoint_path=checkpoint_path,
88+
config=get_decoder_config(**kwargs),
89+
tensor_names=TENSOR_NAMES,
90+
model_class=Decoder,
91+
)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2024 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""Verifies the reauthored decoder of Qwen 2.5 VL 3B models."""
17+
18+
import logging
19+
import pathlib
20+
21+
from absl import app
22+
from ai_edge_torch.generative.examples.qwen_vl import decoder
23+
from ai_edge_torch.generative.utilities import verifier
24+
import torch
25+
import transformers
26+
27+
28+
class DecoderWrapper(verifier.ModelWrapper):
29+
"""Wraps the decoder of Qwen 2.5 VL models for verification."""
30+
31+
def __init__(self, model: torch.nn.Module, lm_head: torch.nn.Module):
32+
super().__init__(model)
33+
self.lm_head = lm_head
34+
35+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
36+
output = self.model.forward(tokens)
37+
return self.lm_head(output["last_hidden_state"])
38+
39+
40+
def main(_):
41+
checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
42+
logging.info("Loading the original model from: %s", checkpoint)
43+
original_model = (
44+
transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained(
45+
checkpoint
46+
)
47+
)
48+
49+
# Locate the cached dir.
50+
cached_config_file = transformers.utils.cached_file(
51+
checkpoint, transformers.utils.CONFIG_NAME
52+
)
53+
reauthored_checkpoint = pathlib.Path(cached_config_file).parent
54+
logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
55+
reauthored_model = decoder.build_decoder(reauthored_checkpoint)
56+
57+
# Verify the reauthored model only with input IDs because the original decoder
58+
# does not support generate() with prompts.
59+
input_ids = [1, 2, 3, 4]
60+
try:
61+
verifier.verify_with_input_ids(
62+
original_model=DecoderWrapper(
63+
original_model.model,
64+
original_model.lm_head,
65+
),
66+
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
67+
input_ids=input_ids,
68+
atol=1e-04,
69+
)
70+
except AssertionError as e:
71+
logging.error("*** FAILED *** verify with input IDs: %s", e)
72+
else:
73+
logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
74+
75+
76+
if __name__ == "__main__":
77+
app.run(main)

0 commit comments

Comments
 (0)