Skip to content

support ernie_vl #4763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions swift/llm/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class MLLMModelType:
step_audio = 'step_audio'
kimi_vl = 'kimi_vl'
keye_vl = 'keye_vl'
ernie_vl = 'ernie_vl'

phi3_vision = 'phi3_vision'
phi4_multimodal = 'phi4_multimodal'
Expand Down
32 changes: 30 additions & 2 deletions swift/llm/model/model/baidu.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.llm import TemplateType
from swift.utils import get_logger
from ..constant import LLMModelType
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
get_model_tokenizer_with_flash_attn, register_model)

logger = get_logger()

Expand All @@ -25,3 +27,29 @@
get_model_tokenizer_with_flash_attn,
architectures=['Ernie4_5_ForCausalLM', 'Ernie4_5_MoeForCausalLM'],
))


def get_model_tokenizer_ernie_vl(*args, **kwargs):
    """Load an ERNIE-4.5-VL model plus its processor.

    Delegates to the generic multimodal loader, then wires the processor's
    image preprocessing into the model via ``add_image_preprocess``. The hook
    is skipped when no model object was instantiated (tokenizer-only load).
    """
    model, processor = get_model_tokenizer_multimodal(*args, **kwargs)
    if model is None:
        return model, processor
    model.add_image_preprocess(processor)
    return model, processor


# Register the ERNIE-4.5-VL model family so swift can resolve these model ids
# to the loader, chat template and module-path mapping defined in this file.
register_model(
    ModelMeta(
        MLLMModelType.ernie_vl,
        [
            ModelGroup([
                # Each entry pairs two hub ids for the same checkpoint
                # (presumably ModelScope id first, HF id second — matches the
                # LLM ERNIE registrations above; verify against Model()).
                Model('PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Base-PT', 'baidu/ERNIE-4.5-VL-28B-A3B-Base-PT'),
                Model('PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT', 'baidu/ERNIE-4.5-VL-28B-A3B-PT'),
                Model('PaddlePaddle/ERNIE-4.5-VL-424B-A47B-Base-PT', 'baidu/ERNIE-4.5-VL-424B-A47B-Base-PT'),
                Model('PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT', 'baidu/ERNIE-4.5-VL-424B-A47B-PT'),
            ]),
        ],
        TemplateType.ernie_vl,
        # Custom loader: attaches the processor's image preprocessing hook.
        get_model_tokenizer_ernie_vl,
        model_arch=ModelArch.ernie_vl,
        architectures=['Ernie4_5_VLMoeForConditionalGeneration'],
        requires=['transformers>=4.52'],
    ))
8 changes: 8 additions & 0 deletions swift/llm/model/model_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class MLLMModelArch:
idefics3 = 'idefics3'

got_ocr2 = 'got_ocr2'
ernie_vl = 'ernie_vl'

ovis1_6 = 'ovis1_6'
molmo = 'molmo'
Expand Down Expand Up @@ -546,6 +547,13 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
aligner='model.mm_projector_vary',
vision_tower='model.vision_tower_high',
))
# Module-path mapping for ERNIE-VL so tuner logic (LoRA targeting, freezing
# the vision tower, etc.) can locate each sub-module.
register_model_arch(
    MultiModelKeys(
        MLLMModelArch.ernie_vl,
        language_model='model',
        aligner='model.resampler_model',
        # NOTE(review): unlike the other keys this path has no 'model.'
        # prefix — confirm the vision tower really sits at the top level of
        # Ernie4_5_VLMoeForConditionalGeneration's module tree.
        vision_tower='vision_model',
    ))

if transformers_ge_4_52:
register_model_arch(
Expand Down
6 changes: 4 additions & 2 deletions swift/llm/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def get_max_model_len(config: Union[PretrainedConfig, Dict[str, Any]]) -> Option
INF = int(1e9)
max_model_len = INF

possible_keys = [
possible_keys = {
'seq_length', # qwen, chatglm
'max_position_embeddings', # qwen1.5, llama2
'n_positions', # polylm, phi-2
Expand All @@ -160,7 +160,9 @@ def get_max_model_len(config: Union[PretrainedConfig, Dict[str, Any]]) -> Option
'max_seq_len',
'max_sequence_length',
'max_seq_length',
]
}
if getattr(config, 'model_type', None) == 'ernie4_5_moe_vl':
possible_keys.discard('max_sequence_length')
for key in possible_keys:
max_len_key = HfConfigFactory.get_config_attr(config, key)
if max_len_key is not None:
Expand Down
1 change: 1 addition & 0 deletions swift/llm/template/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class MLLMTemplateType:
step_audio = 'step_audio'
kimi_vl = 'kimi_vl'
keye_vl = 'keye_vl'
ernie_vl = 'ernie_vl'

idefics3 = 'idefics3'
pixtral = 'pixtral'
Expand Down
49 changes: 45 additions & 4 deletions swift/llm/template/template/baidu.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from typing import Optional
from typing import Any, Dict, List, Literal, Optional

from ..constant import LLMTemplateType
from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..utils import Prompt

from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, align_image_inputs, findall, split_tokens
import torch.nn.functional as F

@dataclass
class ERNIETemplateMeta(TemplateMeta):
Expand All @@ -17,3 +19,42 @@ class ERNIETemplateMeta(TemplateMeta):


register_template(ERNIETemplateMeta(LLMTemplateType.ernie))


class ERNIETemplate(Template):
    """Chat template for ERNIE-4.5-VL.

    Replaces ``<image>`` tags with ERNIE's image markers, re-runs the
    processor so each placeholder expands to the correct number of vision
    tokens, and concatenates the per-sample vision tensors when collating.
    """
    placeholder_tokens = ['<|IMAGE_PLACEHOLDER|>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        # Only images are handled by this template.
        assert media_type == 'image'
        return [f'Picture {index+1}:<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        token_ids = encoded['input_ids']
        token_labels = encoded['labels']
        if not inputs.images:
            return encoded
        # Decode back to text and swap in the processor's own placeholder so
        # the processor can expand each image to its vision-token span.
        decoded_text = self.processor.decode(token_ids)
        processed = self.processor(
            text=[decoded_text.replace('<|IMAGE_PLACEHOLDER|>', '<|image@placeholder|>')],
            images=inputs.images,
            videos=inputs.videos,
            padding=True,
            return_tensors='pt',
        )
        encoded.update(processed)
        expanded_ids = processed['input_ids'][0].tolist()
        # Re-align labels with the expanded token sequence.
        encoded['input_ids'], encoded['labels'] = align_image_inputs(
            token_ids, token_labels, expanded_ids, self.placeholder_tokens[0])
        # Drop the leading batch dim added by return_tensors='pt'.
        encoded['position_ids'] = encoded['position_ids'][0]
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        # Vision tensors concatenate along dim 0; text fields are handled by
        # the base collator, whose output wins on any key collision.
        res = {key: self.concat_tensor(batch, key, 0) for key in ('images', 'grid_thw', 'image_type_ids')}
        res.update(super()._data_collator(batch, padding_to=padding_to))
        return res


# Bind the multimodal ERNIE-VL template: reuses the text ERNIE meta but with
# the image-aware template class.
register_template(ERNIETemplateMeta(MLLMTemplateType.ernie_vl, template_cls=ERNIETemplate))
12 changes: 12 additions & 0 deletions swift/llm/template/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ def fetch_one(element: Union[Tuple, List, Set, Dict, Any], item_type: Optional[T
return element


def split_tokens(tokens: List[int], sub_tokens: Union[int, List[int]]) -> List[List[int]]:
    """Split *tokens* into segments at every occurrence of *sub_tokens*.

    Returns ``len(occurrences) + 1`` segments (some possibly empty); the
    separator position itself is excluded from the segments.

    NOTE(review): when *sub_tokens* is a multi-token list, only the token at
    the matched start index is skipped (``start + 1``) — confirm that
    ``findall`` returns start indices and this partial skip is intended.
    """
    boundaries = [-1, *findall(tokens, sub_tokens), len(tokens)]
    return [tokens[start + 1:end] for start, end in zip(boundaries, boundaries[1:])]


def findall(token_list: List[int], sub_token_list: Union[int, List[int]]) -> List[int]:
"""Find the index of a token in the token_list."""
if isinstance(sub_token_list, int):
Expand All @@ -82,6 +92,8 @@ def align_image_inputs(input_ids: List[int], labels: List[int], new_input_ids,
if isinstance(new_input_ids, torch.Tensor):
new_input_ids = new_input_ids.tolist()

if labels is None:
return new_input_ids, None
# Find the tokens after the image_token in input_ids, and then align them.
i, j = 0, 0
while i < len(input_ids):
Expand Down
15 changes: 14 additions & 1 deletion tests/test_align/test_template/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,18 @@ def test_keye_vl():
assert response == response2


def test_ernie_vl():
    """ERNIE-VL: the swift template and the jinja backend must agree."""
    pt_engine = PtEngine('PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT')
    messages = [{'role': 'user', 'content': '<image><image>What is the difference between the two images?'}]
    images = [
        'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png',
        'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png'
    ]
    swift_response = _infer_model(pt_engine, messages=messages, images=images)
    # Re-run the same request through the jinja chat-template backend.
    pt_engine.default_template.template_backend = 'jinja'
    jinja_response = _infer_model(pt_engine, messages=messages, images=images)
    assert swift_response == jinja_response

if __name__ == '__main__':
from swift.llm import PtEngine, RequestConfig
from swift.utils import get_logger, seed_everything
Expand Down Expand Up @@ -660,4 +672,5 @@ def test_keye_vl():
# test_kimi_vl_thinking()
# test_glm4_1v()
# test_gemma3n()
test_keye_vl()
# test_keye_vl()
test_ernie_vl()
Loading