Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions swift/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ class MLLMModelType:
idefics3 = 'idefics3'
paligemma = 'paligemma'
molmo = 'molmo'
molmo2 = 'molmo2'
molmoe = 'molmoe'
pixtral = 'pixtral'
megrez_omni = 'megrez_omni'
Expand Down
4 changes: 2 additions & 2 deletions swift/model/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import (baai, baichuan, baidu, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba,
microsoft, minicpm, minimax, mistral, mllm, moonshot, mplug, openbuddy, qwen, seed, skywork, stepfun,
telechat, tencent, valley, yi)
microsoft, minicpm, minimax, mistral, mllm, molmo2, moonshot, mplug, openbuddy, qwen, seed, skywork,
stepfun, telechat, tencent, valley, yi)
112 changes: 112 additions & 0 deletions swift/model/models/molmo2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import transformers
from contextlib import contextmanager
from packaging import version
from transformers import PreTrainedModel
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from swift.template import TemplateType
from swift.utils import get_logger
from ..constant import MLLMModelType
from ..model_arch import ModelArch
from ..model_meta import Model, ModelGroup, ModelMeta
from ..patcher import patch_output_clone
from ..register import ModelLoader, register_model

logger = get_logger()


class Molmo2Loader(ModelLoader):
    """Loader for Molmo2 multimodal models.

    Handles two Molmo2-specific quirks on top of the generic ``ModelLoader``:
    a transformers>=5 ``ProcessorMixin`` compatibility shim for optional
    processor attributes, and a fallback from flash-attn to SDPA for the
    vision pooling module (see the methods below).
    """

    @staticmethod
    @contextmanager
    def _patch_processor_optional_attributes_compat():
        """Restrict ProcessorMixin compat to Molmo2 processor loading only.

        On transformers >= 5.0.0.dev, ``ProcessorMixin.__init__`` may reject
        optional attributes passed as kwargs. This context manager temporarily
        replaces ``__init__`` with a wrapper that pops those kwargs before the
        original init runs and re-attaches them afterwards, then restores the
        original ``__init__`` on exit. On older transformers (or if the import
        fails) it is a no-op.
        """
        # No patch needed before the 5.x behavior change.
        if version.parse(transformers.__version__) < version.parse('5.0.0.dev'):
            yield
            return
        try:
            from transformers.processing_utils import ProcessorMixin
        except Exception:
            # Best-effort: if the internal module layout changed, skip patching.
            yield
            return

        origin_init = ProcessorMixin.__init__

        def _patched_init(self, *args, **kwargs):
            # 'chat_template' and 'audio_tokenizer' are handled natively by
            # ProcessorMixin, so they are deliberately left in kwargs.
            optional_attributes = getattr(self, 'optional_attributes', None) or []
            optional_values = {}
            for key in optional_attributes:
                if key in {'chat_template', 'audio_tokenizer'}:
                    continue
                if key in kwargs:
                    optional_values[key] = kwargs.pop(key)

            origin_init(self, *args, **kwargs)

            # Re-attach popped values; default missing optional attrs to None
            # so downstream attribute access does not raise.
            for key in optional_attributes:
                if key in {'chat_template', 'audio_tokenizer'}:
                    continue
                if key in optional_values:
                    setattr(self, key, optional_values[key])
                elif not hasattr(self, key):
                    setattr(self, key, None)

        ProcessorMixin.__init__ = _patched_init
        try:
            yield
        finally:
            # Always restore the original init, even if loading raised.
            ProcessorMixin.__init__ = origin_init

@staticmethod
def _patch_vision_pooling_attention(model: PreTrainedModel) -> None:
inner_model = getattr(model, 'model', None)
if inner_model is None:
return

vision_backbone = getattr(inner_model, 'vision_backbone', None)
if vision_backbone is None:
return
pooling = getattr(vision_backbone, 'image_pooling_2d', None)
if pooling is None or getattr(pooling, 'attn_implementation', None) != 'flash_attention_2':
return

pooling.attn_implementation = 'sdpa'
adapter_config = getattr(vision_backbone, 'adapter_config', None)
if adapter_config is not None and getattr(adapter_config, 'attn_implementation', None) == 'flash_attention_2':
adapter_config.attn_implementation = 'sdpa'
logger.info('Set Molmo2 vision_backbone.image_pooling_2d attention to `sdpa` to avoid '
'flash-attn varlen failures on padded video batches.')

    def get_processor(self, model_dir, config):
        """Load the Molmo2 processor with the ProcessorMixin compat shim active.

        The shim (see ``_patch_processor_optional_attributes_compat``) is only
        in effect for the duration of this call, so other processor loads in
        the process are unaffected.
        """
        with self._patch_processor_optional_attributes_compat():
            return super().get_processor(model_dir, config)

def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
from transformers import AutoModelForImageTextToText
model_cls = get_class_from_dynamic_module('modeling_molmo2.Molmo2ForConditionalGeneration', model_dir)
model_cls._no_split_modules = getattr(model_cls, '_no_split_modules', []) or ['MolmoSequentialBlock']
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current logic for setting _no_split_modules will overwrite the default list if it's empty, but it won't append to it if it already contains other modules. It's safer to ensure MolmoSequentialBlock is included in the list without discarding existing entries.

Suggested change
model_cls._no_split_modules = getattr(model_cls, '_no_split_modules', []) or ['MolmoSequentialBlock']
no_split_modules = getattr(model_cls, '_no_split_modules', []) or []
if 'MolmoSequentialBlock' not in no_split_modules:
model_cls._no_split_modules = no_split_modules + ['MolmoSequentialBlock']

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in commit 6eaf502: I now preserve existing _no_split_modules entries and only append MolmoSequentialBlock when it is missing.

self.auto_model_cls = self.auto_model_cls or AutoModelForImageTextToText
model = super().get_model(model_dir, *args, **kwargs)
patch_output_clone(model.model.transformer.wte)
self._patch_vision_pooling_attention(model)
return model


# Register the Molmo2 model family (ModelScope id, Hugging Face id pairs).
register_model(
    ModelMeta(
        MLLMModelType.molmo2,
        [
            ModelGroup([
                Model('LLM-Research/Molmo2-4B', 'allenai/Molmo2-4B'),
                Model('LLM-Research/Molmo2-8B', 'allenai/Molmo2-8B'),
                Model('LLM-Research/Molmo2-O-7B', 'allenai/Molmo2-O-7B'),
            ]),
        ],
        Molmo2Loader,
        template=TemplateType.molmo2,
        model_arch=ModelArch.molmo,
        architectures=['Molmo2ForConditionalGeneration'],
        tags=['vision', 'video'],
        # NOTE: 4.57.1 is a real release and the minimum validated for the
        # Molmo2 processor/runtime path (smoke-tested with 4.57.3); do not
        # relax without additional compatibility coverage.
        requires=['transformers>=4.57.1', 'decord'],
    ))
1 change: 1 addition & 0 deletions swift/template/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ class MLLMTemplateType:
phi4_multimodal = 'phi4_multimodal'
florence = 'florence'
molmo = 'molmo'
molmo2 = 'molmo2'
megrez_omni = 'megrez_omni'
valley = 'valley'
gemma3_vision = 'gemma3_vision'
Expand Down
4 changes: 2 additions & 2 deletions swift/template/templates/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import (baai, baidu, bert, deepseek, dots, gemma, glm, idefics3, internlm, internvl, kwai, llama, llava, llm,
megrez, microsoft, midashenglm, minicpm, minimax, minimind, mistral, molmo, moonshot, mplug, openbuddy,
pixtral, qwen, seed, stepfun, tencent, valley, yi)
megrez, microsoft, midashenglm, minicpm, minimax, minimind, mistral, molmo, molmo2, moonshot, mplug,
openbuddy, pixtral, qwen, seed, stepfun, tencent, valley, yi)
Loading