Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 139 additions & 2 deletions swift/template/templates/minicpm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import math
import torch
from dataclasses import dataclass, field
from functools import partial
Expand All @@ -11,7 +12,7 @@
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, findall
from ..vision_utils import load_video_minicpmv_mplug_owl3
from ..vision_utils import load_audio, load_video_minicpmv_mplug_owl3
from .llama import Llama3TemplateMeta
from .qwen import Qwen2_5TemplateMeta, Qwen3MixedTemplateMeta, QwenTemplateMeta
from .utils import ChatmlTemplateMeta
Expand Down Expand Up @@ -240,10 +241,146 @@ def _get_new_tokens(i):
template_cls=MiniCPMV2_6Template,
))

class MiniCPMO4_5Template(MiniCPMV2_6Template):
    """MiniCPM-o-4_5 template: supports video + audio → text training.

    Audio placeholder: <|audio_start|><unk>*N<|audio_end|>
    Model inputs added: audio_features, audio_feature_lens, audio_bounds
    Audio is truncated to 60s max; sampling_rate defaults to 16000.
    """

    SAMPLING_RATE = 16000  # Hz; rate the audio feature extractor expects
    MAX_AUDIO_SECONDS = 60  # hard cap on clip length fed to the audio encoder

    def _get_audio_placeholder(self, audio_len_samples: int) -> str:
        """Compute the audio placeholder string for a given waveform length.

        Mirrors the feature pipeline: STFT framing by ``hop_length``, a
        stride-2 conv downsample, then token pooling by ``pool_step``.
        Returns ``<|audio_start|><unk>*N<|audio_end|>`` with N >= 1.
        """
        pool_step = self.processor.pool_step  # typically 5
        hop_length = self.processor.audio_processor.hop_length  # Whisper: 160
        feature_lens = math.ceil(audio_len_samples / hop_length)
        feature_lens = (feature_lens - 1) // 2 + 1  # stride-2 conv downsample
        output_lens = max(1, (feature_lens - pool_step) // pool_step + 1)
        audio_start = self.processor.tokenizer.audio_start  # '<|audio_start|>'
        audio_end = self.processor.tokenizer.audio_end  # '<|audio_end|>'
        return audio_start + '<unk>' * output_lens + audio_end

    def _load_and_truncate_audio(self, audio):
        """Load *audio* (path or waveform array) and cap it at MAX_AUDIO_SECONDS.

        Fix: the cap previously applied only to string paths; waveform arrays
        passed directly by the caller skipped truncation entirely.
        """
        if isinstance(audio, str):
            audio = load_audio(audio, sampling_rate=self.SAMPLING_RATE)
        max_samples = self.MAX_AUDIO_SECONDS * self.SAMPLING_RATE
        if len(audio) > max_samples:
            audio = audio[:max_samples]
        return audio

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        """Replace one media tag with its token-placeholder context list."""
        if media_type == 'audio':
            audio_np = self._load_and_truncate_audio(inputs.audios[index])
            # Store the (possibly truncated) waveform so _encode sees arrays.
            inputs.audios[index] = audio_np
            return [self._get_audio_placeholder(len(audio_np))]
        # video / image: delegate to parent (assert also removed for video)
        image_context = MiniCPMVTemplate.replace_tag(self, 'image', index, inputs)
        if media_type == 'image':
            return image_context
        # video: build the frame loader only on this branch — it is unused
        # for images.
        load_video = partial(load_video_minicpmv_mplug_owl3, max_num_frames=self.max_num_frames)
        return self.replace_video2image(load_video, inputs, lambda i: image_context)

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        """Encode inputs, adding audio features/bounds on top of the parent's
        image/video handling (pixel_values, image_bound, tgt_sizes)."""
        encoded = MiniCPMV2_6Template._encode(self, inputs)

        input_ids = torch.tensor(encoded['input_ids'])
        audio_start_id = self.processor.tokenizer.audio_start_id
        audio_end_id = self.processor.tokenizer.audio_end_id
        start_positions = torch.where(input_ids == audio_start_id)[0]
        end_positions = torch.where(input_ids == audio_end_id)[0]

        # Fix image_bound: MiniCPMV2_6Template._encode finds ALL <unk> runs,
        # but audio placeholders also use <unk> tokens. We must exclude them,
        # otherwise get_vllm_embedding's torch.stack fails due to mismatched sizes.
        if len(start_positions) > 0:
            in_audio = torch.zeros(len(input_ids), dtype=torch.bool)
            for s, e in zip(start_positions.tolist(), end_positions.tolist()):
                in_audio[s:e + 1] = True
            unk_token = self.processor.encode('<unk>', add_special_tokens=False)[0]
            image_unk_mask = (input_ids == unk_token) & ~in_audio
            indices = image_unk_mask.nonzero(as_tuple=True)[0].tolist()
            if indices:
                # Collapse consecutive indices into half-open [start, end) runs.
                ranges = []
                run_start = indices[0]
                for i in range(1, len(indices)):
                    if indices[i] != indices[i - 1] + 1:
                        ranges.append([run_start, indices[i - 1] + 1])
                        run_start = indices[i]
                ranges.append([run_start, indices[-1] + 1])
                encoded['image_bound'] = [torch.tensor(ranges)]
            else:
                # Fix: keep the same element type as the non-empty branch
                # (a tensor, not a bare list) so downstream per-sample
                # iteration/indexing sees a uniform structure.
                encoded['image_bound'] = [torch.zeros((0, 2), dtype=torch.long)]

        if not inputs.audios:
            encoded['audio_features'] = []
            encoded['audio_feature_lens'] = []
            encoded['audio_bounds'] = torch.zeros((0, 2), dtype=torch.long)
            return encoded

        # audios already loaded as np.ndarray by replace_tag
        audios = inputs.audios
        audio_result = self.processor.process_audio(audios=audios, sampling_rate=self.SAMPLING_RATE)
        audio_features = audio_result['audio_features']  # (total_chunks, 80, max_frames)
        audio_feature_lens = audio_result['audio_feature_lens']  # [tensor([len1, ...])]

        # Fix: raise instead of assert — asserts are stripped under `python -O`
        # and this is a real data-integrity check, not a debug hint.
        if len(start_positions) != len(end_positions):
            raise ValueError(f'audio_start/end token count mismatch: '
                             f'{len(start_positions)} vs {len(end_positions)}')
        if len(start_positions) > 0:
            # Bounds are [first token AFTER <|audio_start|>, <|audio_end|>).
            audio_bounds = torch.hstack([
                (start_positions + 1).unsqueeze(-1),
                end_positions.unsqueeze(-1),
            ])
        else:
            audio_bounds = torch.zeros((0, 2), dtype=torch.long)

        encoded['audio_features'] = audio_features
        encoded['audio_feature_lens'] = audio_feature_lens
        encoded['audio_bounds'] = audio_bounds
        return encoded

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve multimodal embeddings; the model consumes inputs_embeds."""
        inputs_embeds, _ = model.get_vllm_embedding(inputs)
        return {'inputs_embeds': inputs_embeds}

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate a batch, padding audio_features to a common frame length.

        audio_feature_lens / audio_bounds stay as per-sample lists (one entry
        per batch item, empty for items without audio).
        """
        res = super()._data_collator(batch, padding_to=padding_to)

        # collate audio fields: keep only non-empty tensors, right-pad the
        # frame axis to the batch max, then concatenate along the chunk axis.
        audio_features_list = [
            b['audio_features'] for b in batch
            if isinstance(b.get('audio_features'), torch.Tensor) and b['audio_features'].numel() > 0
        ]
        if audio_features_list:
            max_frames = max(af.shape[-1] for af in audio_features_list)
            padded = []
            for af in audio_features_list:
                pad_len = max_frames - af.shape[-1]
                padded.append(torch.nn.functional.pad(af, (0, pad_len)))
            res['audio_features'] = torch.cat(padded, dim=0)
        else:
            res['audio_features'] = []

        # audio_feature_lens: list of tensors (one per batch item with audio)
        res['audio_feature_lens'] = [
            b['audio_feature_lens'] for b in batch
        ]
        # audio_bounds: list of tensors (one per batch item)
        res['audio_bounds'] = [b['audio_bounds'] for b in batch]
        return res


register_template(
Qwen3MixedTemplateMeta(
MLLMTemplateType.minicpmo4_5,
template_cls=MiniCPMV2_6Template,
template_cls=MiniCPMO4_5Template,
is_thinking=True,
))

Expand Down
Loading