diff --git a/swift/model/constant.py b/swift/model/constant.py index 204cc3f6eb..00746caa13 100644 --- a/swift/model/constant.py +++ b/swift/model/constant.py @@ -239,6 +239,7 @@ class MLLMModelType: idefics3 = 'idefics3' paligemma = 'paligemma' molmo = 'molmo' + molmo2 = 'molmo2' molmoe = 'molmoe' pixtral = 'pixtral' megrez_omni = 'megrez_omni' diff --git a/swift/model/models/__init__.py b/swift/model/models/__init__.py index 5a0661cfe1..8c67705142 100644 --- a/swift/model/models/__init__.py +++ b/swift/model/models/__init__.py @@ -1,3 +1,3 @@ from . import (baai, baichuan, baidu, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba, - microsoft, minicpm, minimax, mistral, mllm, moonshot, mplug, openbuddy, qwen, seed, skywork, stepfun, - telechat, tencent, valley, yi) + microsoft, minicpm, minimax, mistral, mllm, molmo2, moonshot, mplug, openbuddy, qwen, seed, skywork, + stepfun, telechat, tencent, valley, yi) diff --git a/swift/model/models/molmo2.py b/swift/model/models/molmo2.py new file mode 100644 index 0000000000..a26b840969 --- /dev/null +++ b/swift/model/models/molmo2.py @@ -0,0 +1,114 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import transformers +from contextlib import contextmanager +from packaging import version +from transformers import PreTrainedModel +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +from swift.template import TemplateType +from swift.utils import get_logger +from ..constant import MLLMModelType +from ..model_arch import ModelArch +from ..model_meta import Model, ModelGroup, ModelMeta +from ..patcher import patch_output_clone +from ..register import ModelLoader, register_model + +logger = get_logger() + + +class Molmo2Loader(ModelLoader): + + @staticmethod + @contextmanager + def _patch_processor_optional_attributes_compat(): + """Restrict ProcessorMixin compat to Molmo2 processor loading only.""" + if version.parse(transformers.__version__) < version.parse('5.0.0.dev'): + yield + return + try: + from transformers.processing_utils import ProcessorMixin + except Exception: + yield + return + + origin_init = ProcessorMixin.__init__ + + def _patched_init(self, *args, **kwargs): + optional_attributes = getattr(self, 'optional_attributes', None) or [] + optional_values = {} + for key in optional_attributes: + if key in {'chat_template', 'audio_tokenizer'}: + continue + if key in kwargs: + optional_values[key] = kwargs.pop(key) + + origin_init(self, *args, **kwargs) + + for key in optional_attributes: + if key in {'chat_template', 'audio_tokenizer'}: + continue + if key in optional_values: + setattr(self, key, optional_values[key]) + elif not hasattr(self, key): + setattr(self, key, None) + + ProcessorMixin.__init__ = _patched_init + try: + yield + finally: + ProcessorMixin.__init__ = origin_init + + @staticmethod + def _patch_vision_pooling_attention(model: PreTrainedModel) -> None: + inner_model = getattr(model, 'model', None) + if inner_model is None: + return + + vision_backbone = getattr(inner_model, 'vision_backbone', None) + if vision_backbone is None: + return + pooling = getattr(vision_backbone, 'image_pooling_2d', None) + if pooling is None or getattr(pooling, 'attn_implementation', None) != 'flash_attention_2': + return + + pooling.attn_implementation = 'sdpa' + adapter_config = getattr(vision_backbone, 'adapter_config', None) + if adapter_config is not None and getattr(adapter_config, 'attn_implementation', None) == 'flash_attention_2': + adapter_config.attn_implementation = 'sdpa' + logger.info('Set Molmo2 vision_backbone.image_pooling_2d attention to `sdpa` to avoid ' + 'flash-attn varlen failures on padded video batches.') + + def get_processor(self, model_dir, config): + with self._patch_processor_optional_attributes_compat(): + return super().get_processor(model_dir, config) + + def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel: + from transformers import AutoModelForImageTextToText + model_cls = get_class_from_dynamic_module('modeling_molmo2.Molmo2ForConditionalGeneration', model_dir) + no_split_modules = getattr(model_cls, '_no_split_modules', []) or [] + if 'MolmoSequentialBlock' not in no_split_modules: + model_cls._no_split_modules = no_split_modules + ['MolmoSequentialBlock'] + self.auto_model_cls = self.auto_model_cls or AutoModelForImageTextToText + model = super().get_model(model_dir, *args, **kwargs) + patch_output_clone(model.model.transformer.wte) + self._patch_vision_pooling_attention(model) + return model + + +register_model( + ModelMeta( + MLLMModelType.molmo2, + [ + ModelGroup([ + Model('LLM-Research/Molmo2-4B', 'allenai/Molmo2-4B'), + Model('LLM-Research/Molmo2-8B', 'allenai/Molmo2-8B'), + Model('LLM-Research/Molmo2-O-7B', 'allenai/Molmo2-O-7B'), + ]), + ], + Molmo2Loader, + template=TemplateType.molmo2, + model_arch=ModelArch.molmo, + architectures=['Molmo2ForConditionalGeneration'], + tags=['vision', 'video'], + requires=['transformers>=4.57.1', 'decord'], + )) diff --git a/swift/template/constant.py b/swift/template/constant.py index ab20b50233..ecd2d8a44e 100644 --- a/swift/template/constant.py +++ b/swift/template/constant.py @@ -243,6 +243,7 @@ class MLLMTemplateType: phi4_multimodal = 'phi4_multimodal' florence = 'florence' molmo = 'molmo' + molmo2 = 'molmo2' megrez_omni = 'megrez_omni' valley = 'valley' gemma3_vision = 'gemma3_vision' diff --git a/swift/template/templates/__init__.py b/swift/template/templates/__init__.py index 1af552c8ed..5f7f20b2e4 100644 --- a/swift/template/templates/__init__.py +++ b/swift/template/templates/__init__.py @@ -1,3 +1,3 @@ from . import (baai, baidu, bert, deepseek, dots, gemma, glm, idefics3, internlm, internvl, kwai, llama, llava, llm, - megrez, microsoft, midashenglm, minicpm, minimax, minimind, mistral, molmo, moonshot, mplug, openbuddy, - pixtral, qwen, seed, stepfun, tencent, valley, yi) + megrez, microsoft, midashenglm, minicpm, minimax, minimind, mistral, molmo, molmo2, moonshot, mplug, + openbuddy, pixtral, qwen, seed, stepfun, tencent, valley, yi) diff --git a/swift/template/templates/molmo2.py b/swift/template/templates/molmo2.py new file mode 100644 index 0000000000..15c11fa5d0 --- /dev/null +++ b/swift/template/templates/molmo2.py @@ -0,0 +1,277 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import copy +import inspect +import numpy as np +import re +from PIL import Image +from typing import Any, Dict, List, Literal, Tuple + +from ..base import Template +from ..constant import MLLMTemplateType +from ..register import TemplateMeta, register_template +from ..template_inputs import StdTemplateInputs +from ..utils import Context + + +class Molmo2Template(Template): + """Native Molmo2 template for image and video understanding.""" + + use_model = True + + placeholder_tokens = [ + '<|image|>', + '<|video|>', + '', + '', + '', + ] + + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, + inputs: StdTemplateInputs) -> List[Context]: + if media_type == 'image': + return ['<|image|>'] + if media_type == 'video': + return ['<|video|>'] + return [] + + @staticmethod + def _load_video_descriptor(video_item: Any) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]: + if not isinstance(video_item, dict): + raise TypeError('Molmo2 expects a video descriptor dict produced by the dataset preprocessor.') + frame_paths = video_item.get('frame_paths') or [] + timestamps = video_item.get('timestamps') or [] + if not frame_paths or not timestamps or len(frame_paths) != len(timestamps): + raise ValueError('Molmo2 video descriptor requires aligned `frame_paths` and `timestamps`.') + + frames = [] + for frame_path in frame_paths: + with Image.open(frame_path) as image: + frames.append(np.asarray(image.convert('RGB'))) + frame_array = np.stack(frames, axis=0) + timestamp_array = np.asarray(timestamps, dtype=np.float32) + metadata = { + 'frame_paths': frame_paths, + 'source_video': video_item.get('source_video'), + 'num_frames': len(frame_paths), + } + return frame_array, timestamp_array, metadata + + @staticmethod + def _build_messages_for_processor(inputs: StdTemplateInputs) -> List[Dict[str, Any]]: + messages = copy.deepcopy(inputs.messages) + image_idx = 0 + video_idx = 0 + for message in messages: + content = message.get('content', '') + structured_content: List[Dict[str, Any]] = [] + if not isinstance(content, str): + message['content'] = content + continue + for chunk in re.split(r'(|