Skip to content

[Feature] Add Molmo2 model and template support#9063

Draft
Kagura-0001 wants to merge 2 commits intomodelscope:mainfrom
Kagura-0001:codex/add-molmo2-support-pr
Draft

[Feature] Add Molmo2 model and template support#9063
Kagura-0001 wants to merge 2 commits intomodelscope:mainfrom
Kagura-0001:codex/add-molmo2-support-pr

Conversation

@Kagura-0001
Copy link
Copy Markdown

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

This PR adds initial Molmo2 support to ms-swift.

What changed

  • register molmo2 as a new MLLM model type and template type
  • add swift/model/models/molmo2.py
  • add swift/template/templates/molmo2.py
  • wire Molmo2 into model/template package registries
  • register the following model ids:
    • allenai/Molmo2-4B
    • allenai/Molmo2-8B
    • allenai/Molmo2-O-7B
  • add a unittest to verify Molmo2 registration and template wiring

Notes

  • the loader includes a ProcessorMixin compatibility patch for newer transformers
  • the loader also switches Molmo2 video pooling attention from flash_attention_2 to sdpa when needed to avoid padded video batch failures

Validation

  • pre-commit run --all-files
  • python tests/run.py --test_dir tests/general --pattern test_model.py
  • local smoke test with /mnt/bn/strategy-mllm-train/user/weisong/repo/motion_benchmarks/pretrained_models/Molmo2-4B confirmed get_model_processor(..., load_model=False) and template.encode(...) succeed

Related issue

Experiment results

N/A

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Molmo2 model family, including registration of 4B, 8B, and O-7B variants, along with a dedicated template for image and video understanding. Key additions include the Molmo2Loader with compatibility patches for transformers and vision attention, and the Molmo2Template for handling multi-modal inputs. Feedback focuses on correcting a version requirement typo, improving the robustness of module splitting logic, preventing potential division-by-zero errors in FPS calculation, and replacing assertions with explicit value errors for input validation.

model_arch=ModelArch.molmo,
architectures=['Molmo2ForConditionalGeneration'],
tags=['vision', 'video'],
requires=['transformers>=4.57.1', 'decord'],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The version requirement transformers>=4.57.1 appears to be a typo, as this version does not exist yet (the current stable version is around 4.48). Molmo models typically require transformers>=4.45.0.

Suggested change
requires=['transformers>=4.57.1', 'decord'],
requires=['transformers>=4.45.0', 'decord'],

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept transformers>=4.57.1 here. 4.57.1 is a real released version now, and the local Molmo2 smoke validation for this PR succeeded with transformers==4.57.3. I would prefer to keep the newer minimum for the current Molmo2 processor/runtime path rather than relax it to 4.45.0 without additional compatibility coverage.

def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
from transformers import AutoModelForImageTextToText
model_cls = get_class_from_dynamic_module('modeling_molmo2.Molmo2ForConditionalGeneration', model_dir)
model_cls._no_split_modules = getattr(model_cls, '_no_split_modules', []) or ['MolmoSequentialBlock']
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current logic for setting _no_split_modules will overwrite the default list if it's empty, but it won't append to it if it already contains other modules. It's safer to ensure MolmoSequentialBlock is included in the list without discarding existing entries.

Suggested change
model_cls._no_split_modules = getattr(model_cls, '_no_split_modules', []) or ['MolmoSequentialBlock']
no_split_modules = getattr(model_cls, '_no_split_modules', []) or []
if 'MolmoSequentialBlock' not in no_split_modules:
model_cls._no_split_modules = no_split_modules + ['MolmoSequentialBlock']

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in commit 6eaf502: I now preserve existing _no_split_modules entries and only append MolmoSequentialBlock when it is missing.

if len(timestamps) <= 1:
fps = 1.0
else:
fps = 1.0 / float(np.median(np.diff(timestamps)))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calculating fps by dividing by the median of timestamp differences can lead to a ZeroDivisionError if the input timestamps are identical or if the median difference is zero. It's safer to add a check or a fallback value.

Suggested change
fps = 1.0 / float(np.median(np.diff(timestamps)))
median_diff = np.median(np.diff(timestamps))
fps = 1.0 / float(median_diff) if median_diff > 0 else 1.0

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in commit 6eaf502: I added a guard for zero/non-positive median timestamp deltas and fall back to fps = 1.0 in that case.

media_inputs.update(image_inputs)

if inputs.videos:
assert len(inputs.videos) == 1, 'Molmo2 currently only supports single-video samples.'
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using assert for input validation is discouraged as it can be optimized away in production (when running with -O). It is better to raise a ValueError to provide a clear error message to the user.

Suggested change
assert len(inputs.videos) == 1, 'Molmo2 currently only supports single-video samples.'
if len(inputs.videos) != 1:
raise ValueError('Molmo2 currently only supports single-video samples.')

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in commit 6eaf502: I replaced the assert with an explicit ValueError so the validation is preserved in optimized runs as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

支持Molmo2计划:Molmo2 is a family of open VLM developed by the Allen Institute。

1 participant