diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py
index a951c777d3..210b3f021b 100644
--- a/swift/llm/model/constant.py
+++ b/swift/llm/model/constant.py
@@ -221,6 +221,7 @@ class MLLMModelType:
     step_audio = 'step_audio'
     kimi_vl = 'kimi_vl'
     keye_vl = 'keye_vl'
+    ernie_vl = 'ernie_vl'
 
     phi3_vision = 'phi3_vision'
     phi4_multimodal = 'phi4_multimodal'
diff --git a/swift/llm/model/model/baidu.py b/swift/llm/model/model/baidu.py
index a472d24bc0..0cad2eb10a 100644
--- a/swift/llm/model/model/baidu.py
+++ b/swift/llm/model/model/baidu.py
@@ -1,8 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from swift.llm import TemplateType
 from swift.utils import get_logger
-from ..constant import LLMModelType
-from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
+from ..constant import LLMModelType, MLLMModelType
+from ..model_arch import ModelArch
+from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
+                        get_model_tokenizer_with_flash_attn, register_model)
 
 logger = get_logger()
 
@@ -25,3 +27,29 @@
         get_model_tokenizer_with_flash_attn,
         architectures=['Ernie4_5_ForCausalLM', 'Ernie4_5_MoeForCausalLM'],
     ))
+
+
+def get_model_tokenizer_ernie_vl(*args, **kwargs):
+    model, processor = get_model_tokenizer_multimodal(*args, **kwargs)
+    if model is not None:
+        model.add_image_preprocess(processor)
+    return model, processor
+
+
+register_model(
+    ModelMeta(
+        MLLMModelType.ernie_vl,
+        [
+            ModelGroup([
+                Model('PaddlePaddle/ERNIE-4.5-VL-28B-A3B-Base-PT', 'baidu/ERNIE-4.5-VL-28B-A3B-Base-PT'),
+                Model('PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT', 'baidu/ERNIE-4.5-VL-28B-A3B-PT'),
+                Model('PaddlePaddle/ERNIE-4.5-VL-424B-A47B-Base-PT', 'baidu/ERNIE-4.5-VL-424B-A47B-Base-PT'),
+                Model('PaddlePaddle/ERNIE-4.5-VL-424B-A47B-PT', 'baidu/ERNIE-4.5-VL-424B-A47B-PT'),
+            ]),
+        ],
+        TemplateType.ernie_vl,
+        get_model_tokenizer_ernie_vl,
+        model_arch=ModelArch.ernie_vl,
+        architectures=['Ernie4_5_VLMoeForConditionalGeneration'],
+        requires=['transformers>=4.52'],
+    ))
diff --git a/swift/llm/model/model_arch.py b/swift/llm/model/model_arch.py
index 457ac4543a..d8a62811a0 100644
--- a/swift/llm/model/model_arch.py
+++ b/swift/llm/model/model_arch.py
@@ -64,6 +64,7 @@ class MLLMModelArch:
 
     idefics3 = 'idefics3'
     got_ocr2 = 'got_ocr2'
+    ernie_vl = 'ernie_vl'
 
     ovis1_6 = 'ovis1_6'
     molmo = 'molmo'
@@ -546,6 +547,13 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
         aligner='model.mm_projector_vary',
         vision_tower='model.vision_tower_high',
     ))
+register_model_arch(
+    MultiModelKeys(
+        MLLMModelArch.ernie_vl,
+        language_model='model',
+        aligner='model.resampler_model',
+        vision_tower='vision_model',
+    ))
 
 if transformers_ge_4_52:
     register_model_arch(
diff --git a/swift/llm/model/utils.py b/swift/llm/model/utils.py
index 883a3570e3..a088a61b62 100644
--- a/swift/llm/model/utils.py
+++ b/swift/llm/model/utils.py
@@ -150,7 +150,7 @@ def get_max_model_len(config: Union[PretrainedConfig, Dict[str, Any]]) -> Option
     INF = int(1e9)
     max_model_len = INF
 
-    possible_keys = [
+    possible_keys = {
         'seq_length',  # qwen, chatglm
         'max_position_embeddings',  # qwen1.5, llama2
         'n_positions',  # polylm, phi-2
@@ -160,7 +160,10 @@ def get_max_model_len(config: Union[PretrainedConfig, Dict[str, Any]]) -> Option
         'max_seq_len',
         'max_sequence_length',
         'max_seq_length',
-    ]
+    }
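+    # ERNIE-4.5-VL MoE: ignore 'max_sequence_length' and let the remaining keys determine the max length.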
+    if getattr(config, 'model_type', None) == 'ernie4_5_moe_vl':
+        possible_keys.discard('max_sequence_length')
     for key in possible_keys:
         max_len_key = HfConfigFactory.get_config_attr(config, key)
         if max_len_key is not None:
diff --git a/swift/llm/template/constant.py b/swift/llm/template/constant.py
index c783202e22..67f4f31286 100644
--- a/swift/llm/template/constant.py
+++ b/swift/llm/template/constant.py
@@ -178,6 +178,7 @@ class MLLMTemplateType:
     step_audio = 'step_audio'
     kimi_vl = 'kimi_vl'
     keye_vl = 'keye_vl'
+    ernie_vl = 'ernie_vl'
 
     idefics3 = 'idefics3'
     pixtral = 'pixtral'
diff --git a/swift/llm/template/template/baidu.py b/swift/llm/template/template/baidu.py
index b3b3d45e3b..628f0e3848 100644
--- a/swift/llm/template/template/baidu.py
+++ b/swift/llm/template/template/baidu.py
@@ -1,11 +1,13 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Any, Dict, List, Literal, Optional
 
-from ..constant import LLMTemplateType
+from ..base import Template
+from ..constant import LLMTemplateType, MLLMTemplateType
 from ..register import TemplateMeta, register_template
-from ..utils import Prompt
+from ..template_inputs import StdTemplateInputs
+from ..utils import Context, Prompt, align_image_inputs, findall, split_tokens
 
 
 @dataclass
 class ERNIETemplateMeta(TemplateMeta):
@@ -17,3 +19,42 @@ class ERNIETemplateMeta(TemplateMeta):
 
 
 register_template(ERNIETemplateMeta(LLMTemplateType.ernie))
+
+
+class ERNIETemplate(Template):
+    placeholder_tokens = ['<|IMAGE_PLACEHOLDER|>']
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type == 'image'
+        return [f'Picture {index+1}:<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+        if inputs.images:
+            text = self.processor.decode(input_ids).replace('<|IMAGE_PLACEHOLDER|>', '<|image@placeholder|>')
+            new_inputs = self.processor(
+                text=[text],
+                images=inputs.images,
+                videos=inputs.videos,
+                padding=True,
+                return_tensors='pt',
+            )
+            encoded.update(new_inputs)
+            new_input_ids = new_inputs['input_ids'][0].tolist()
+            encoded['input_ids'], encoded['labels'] = align_image_inputs(input_ids, labels, new_input_ids,
+                                                                         self.placeholder_tokens[0])
+            encoded['position_ids'] = encoded['position_ids'][0]
+        return encoded
+
+    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = {}
+        for key in ['images', 'grid_thw', 'image_type_ids']:
+            res[key] = self.concat_tensor(batch, key, 0)
+        res.update(super()._data_collator(batch, padding_to=padding_to))
+        return res
+
+
+register_template(ERNIETemplateMeta(MLLMTemplateType.ernie_vl, template_cls=ERNIETemplate))
diff --git a/swift/llm/template/utils.py b/swift/llm/template/utils.py
index 172dcb3205..f61d4e2102 100644
--- a/swift/llm/template/utils.py
+++ b/swift/llm/template/utils.py
@@ -61,6 +61,17 @@ def fetch_one(element: Union[Tuple, List, Set, Dict, Any], item_type: Optional[T
     return element
 
 
+def split_tokens(tokens: List[int], sub_tokens: Union[int, List[int]]) -> List[List[int]]:
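+    """Split `tokens` at every index where `sub_tokens` matches; the token at each split position is dropped."""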
+    split_idx = findall(tokens, sub_tokens)
+    split_idx = [-1, *split_idx, len(tokens)]
+    res = []
+    for i in range(len(split_idx) - 1):
+        idx, idx_next = split_idx[i] + 1, split_idx[i + 1]
+        res.append(tokens[idx:idx_next])
+    return res
+
+
 def findall(token_list: List[int], sub_token_list: Union[int, List[int]]) -> List[int]:
     """Find the index of a token in the token_list."""
     if isinstance(sub_token_list, int):
@@ -82,6 +93,8 @@ def align_image_inputs(input_ids: List[int], labels: List[int], new_input_ids,
 
     if isinstance(new_input_ids, torch.Tensor):
         new_input_ids = new_input_ids.tolist()
+    if labels is None:
+        return new_input_ids, None
     # Find the tokens after the image_token in input_ids, and then align them.
     i, j = 0, 0
     while i < len(input_ids):
diff --git a/tests/test_align/test_template/test_vision.py b/tests/test_align/test_template/test_vision.py
index ce09f56d47..a644e58fe2 100644
--- a/tests/test_align/test_template/test_vision.py
+++ b/tests/test_align/test_template/test_vision.py
@@ -604,6 +604,19 @@ def test_keye_vl():
     assert response == response2
 
 
+def test_ernie_vl():
+    pt_engine = PtEngine('PaddlePaddle/ERNIE-4.5-VL-28B-A3B-PT')
+    messages = [{'role': 'user', 'content': 'What is the difference between the two images?'}]
+    images = [
+        'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png',
+        'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png'
+    ]
+    response = _infer_model(pt_engine, messages=messages, images=images)
+    pt_engine.default_template.template_backend = 'jinja'
+    response2 = _infer_model(pt_engine, messages=messages, images=images)
+    assert response == response2
+
+
 if __name__ == '__main__':
     from swift.llm import PtEngine, RequestConfig
     from swift.utils import get_logger, seed_everything
@@ -660,4 +673,5 @@ def test_keye_vl():
     # test_kimi_vl_thinking()
     # test_glm4_1v()
     # test_gemma3n()
-    test_keye_vl()
+    # test_keye_vl()
+    test_ernie_vl()