Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 143 additions & 2 deletions swift/template/templates/minicpm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import math
import torch
from dataclasses import dataclass, field
from functools import partial
Expand All @@ -11,7 +12,7 @@
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, findall
from ..vision_utils import load_video_minicpmv_mplug_owl3
from ..vision_utils import load_audio, load_video_minicpmv_mplug_owl3
from .llama import Llama3TemplateMeta
from .qwen import Qwen2_5TemplateMeta, Qwen3MixedTemplateMeta, QwenTemplateMeta
from .utils import ChatmlTemplateMeta
Expand Down Expand Up @@ -240,10 +241,150 @@ def _get_new_tokens(i):
template_cls=MiniCPMV2_6Template,
))

class MiniCPMO4_5Template(MiniCPMV2_6Template):
    """MiniCPM-o-4_5 template: supports video + audio → text training.

    Audio placeholder: <|audio_start|><unk>*N<|audio_end|>
    Model inputs added: audio_features, audio_feature_lens, audio_bounds
    Audio is truncated to 60s max; sampling_rate defaults to 16000.
    """

    # Target sample rate (Hz) every audio file is resampled to on load.
    SAMPLING_RATE = 16000
    # Hard cap on clip length; longer waveforms are truncated in replace_tag.
    MAX_AUDIO_SECONDS = 60

def _get_audio_placeholder(self, audio_len_samples: int) -> str:
"""Compute the audio placeholder string for a given waveform length."""
pool_step = self.processor.pool_step # typically 5
hop_length = self.processor.audio_processor.hop_length # Whisper: 160
feature_lens = math.ceil(audio_len_samples / hop_length)
feature_lens = (feature_lens - 1) // 2 + 1
output_lens = max(1, (feature_lens - pool_step) // pool_step + 1)
audio_start = self.processor.tokenizer.audio_start # '<|audio_start|>'
audio_end = self.processor.tokenizer.audio_end # '<|audio_end|>'
return audio_start + '<unk>' * output_lens + audio_end

def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                inputs: StdTemplateInputs) -> List[Context]:
    """Replace the index-th media tag with model-specific context tokens.

    For audio, the waveform is loaded (if given as a path/URL), truncated to
    MAX_AUDIO_SECONDS, cached back into ``inputs.audios`` so ``_encode`` does
    not reload it, and replaced by the `<unk>`-run placeholder string.
    Image and video are delegated to the MiniCPMV parent logic.
    """
    if media_type == 'audio':
        audio = inputs.audios[index]
        if isinstance(audio, str):
            audio_np = load_audio(audio, sampling_rate=self.SAMPLING_RATE)
            # truncate to MAX_AUDIO_SECONDS
            max_samples = self.MAX_AUDIO_SECONDS * self.SAMPLING_RATE
            if len(audio_np) > max_samples:
                audio_np = audio_np[:max_samples]
            # Cache the decoded array so later stages see ndarray, not str.
            inputs.audios[index] = audio_np
        else:
            audio_np = audio
        return [self._get_audio_placeholder(len(audio_np))]
    # video / image: delegate to parent (assert also removed for video)
    load_video = partial(load_video_minicpmv_mplug_owl3, max_num_frames=self.max_num_frames)
    image_context = MiniCPMVTemplate.replace_tag(self, 'image', index, inputs)
    if media_type == 'image':
        return image_context
    else:  # video: expand each sampled frame into the same image context
        return self.replace_video2image(load_video, inputs, lambda i: image_context)

def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
    """Encode one sample, adding audio fields on top of the parent's output.

    Returns the parent's encoding extended with ``audio_features``,
    ``audio_feature_lens`` and ``audio_bounds``; also repairs
    ``image_bound`` so audio `<unk>` runs are not mistaken for images.
    """
    # parent handles image/video → pixel_values, image_bound, tgt_sizes
    encoded = MiniCPMV2_6Template._encode(self, inputs)

    input_ids = torch.tensor(encoded['input_ids'])
    audio_start_id = self.processor.tokenizer.audio_start_id
    audio_end_id = self.processor.tokenizer.audio_end_id
    start_positions = torch.where(input_ids == audio_start_id)[0]
    end_positions = torch.where(input_ids == audio_end_id)[0]

    # Fix image_bound: MiniCPMV2_6Template._encode finds ALL <unk> runs,
    # but audio placeholders also use <unk> tokens. We must exclude them,
    # otherwise get_vllm_embedding's torch.stack fails due to mismatched sizes.
    if len(start_positions) > 0:
        # Boolean mask over token positions lying inside any audio span
        # (inclusive of the start/end marker tokens themselves).
        in_audio = torch.zeros(len(input_ids), dtype=torch.bool)
        for s, e in zip(start_positions.tolist(), end_positions.tolist()):
            in_audio[s:e + 1] = True
        unk_token = self.processor.encode('<unk>', add_special_tokens=False)[0]
        image_unk_mask = (input_ids == unk_token) & ~in_audio
        indices = image_unk_mask.nonzero(as_tuple=True)[0].tolist()
        if indices:
            # Group consecutive positions into half-open [start, end) ranges.
            ranges = []
            start = indices[0]
            for i in range(1, len(indices)):
                if indices[i] != indices[i - 1] + 1:
                    ranges.append([start, indices[i - 1] + 1])
                    start = indices[i]
            ranges.append([start, indices[-1] + 1])
            encoded['image_bound'] = [torch.tensor(ranges)]
        else:
            encoded['image_bound'] = [[]]

    if not inputs.audios:
        # No audio in this sample: keep a consistent schema for the collator.
        encoded['audio_features'] = []
        encoded['audio_feature_lens'] = []
        encoded['audio_bounds'] = torch.zeros((0, 2), dtype=torch.long)
        return encoded

    # audios already loaded as np.ndarray by replace_tag
    audios = inputs.audios
    audio_result = self.processor.process_audio(audios=audios, sampling_rate=self.SAMPLING_RATE)
    audio_features = audio_result['audio_features']  # (total_chunks, 80, max_frames)
    audio_feature_lens = audio_result['audio_feature_lens']  # [tensor([len1, ...])]

    assert len(start_positions) == len(end_positions), (
        f'audio_start/end token count mismatch: '
        f'{len(start_positions)} vs {len(end_positions)}'
    )
    if len(start_positions) > 0:
        # Each row is [start + 1, end): the <unk> run strictly between
        # the audio_start and audio_end marker tokens.
        audio_bounds = torch.hstack([
            (start_positions + 1).unsqueeze(-1),
            end_positions.unsqueeze(-1),
        ])
    else:
        audio_bounds = torch.zeros((0, 2), dtype=torch.long)

    encoded['audio_features'] = audio_features
    encoded['audio_feature_lens'] = audio_feature_lens
    encoded['audio_bounds'] = audio_bounds
    return encoded

def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
inputs_embeds, _ = model.get_vllm_embedding(inputs)
return {'inputs_embeds': inputs_embeds}

def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
    """Collate a batch of encoded samples, including the audio fields.

    Image/video fields (pixel_values, image_bound, tgt_sizes) and the text
    fields are delegated to the parent collator instead of being re-gathered
    here (NOTE(review): assumes the parent collator gathers exactly those
    keys — confirm against MiniCPMVTemplate._data_collator).
    """
    res = super()._data_collator(batch, padding_to=padding_to)

    # Collate audio: pad every chunked mel-feature tensor to the longest
    # frame count in the batch, then concatenate along the chunk dimension.
    audio_features_list = [
        b['audio_features'] for b in batch
        if isinstance(b.get('audio_features'), torch.Tensor) and b['audio_features'].numel() > 0
    ]
    if audio_features_list:
        max_frames = max(af.shape[-1] for af in audio_features_list)
        res['audio_features'] = torch.cat(
            [torch.nn.functional.pad(af, (0, max_frames - af.shape[-1])) for af in audio_features_list],
            dim=0)
    else:
        res['audio_features'] = []

    # One entry per batch item; items without audio contribute empty lists /
    # (0, 2) tensors as produced by _encode.
    res['audio_feature_lens'] = [b['audio_feature_lens'] for b in batch]
    res['audio_bounds'] = [b['audio_bounds'] for b in batch]
    return res


# Register the MiniCPM-o-4_5 template under the Qwen3-mixed chat meta.
# NOTE: template_cls must be MiniCPMO4_5Template (not MiniCPMV2_6Template),
# otherwise the audio placeholder / collation logic above is never used.
register_template(
    Qwen3MixedTemplateMeta(
        MLLMTemplateType.minicpmo4_5,
        template_cls=MiniCPMO4_5Template,
        is_thinking=True,
    ))

Expand Down