@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 import transformers
 from packaging import version
+from torch import nn

 from swift.llm import get_packed_seq_params, to_device, to_float_dtype
 from swift.utils import get_env_args, is_deepspeed_enabled
@@ -17,7 +18,7 @@
 from ..template_inputs import StdTemplateInputs
 from ..template_meta import TemplateMeta
 from ..utils import Context, Word, findall
-from ..vision_utils import load_audio, load_batch, load_video_ovis2
+from ..vision_utils import load_audio, load_batch, load_video_ovis2, load_video_ovis2_5
 from .llama import Llama3TemplateMeta
 from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta, ThinkingTemplate

@@ -736,6 +737,86 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int
 ))


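+# Ovis2.5: images/videos are encoded as placeholder token ids during template encoding and
+# replaced with visual embeddings in _post_encode via model.merge_multimodal.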
+class Ovis2_5Template(ThinkingTemplate):
+    num_frames = 8
+    use_model = True
+    skip_prompt = False
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
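+        # -200 is a temporary placeholder id for one image (or one video clip); _encode expands
+        # it to the real number of visual tokens.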
+        if media_type == 'image':
+            return [[-200], '\n']
+        elif media_type == 'video':
+            num_frames = get_env_args('num_frames', int, self.num_frames)
+            inputs.images = load_video_ovis2_5(inputs.videos[index], num_frames)
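+            # Sampled frames are stored on inputs.images and preprocessed as a single clip in _encode.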
+            return [[-200], '\n']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
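+        # The resolution budgets below are defaults; get_env_args allows overriding them via environment variables.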
+        min_pixels = get_env_args('min_pixels', int, 448 * 448)
+        max_pixels = get_env_args('max_pixels', int, 1344 * 1792)
+        video_max_pixels = get_env_args('video_max_pixels', int, 896 * 896)
+        encoded = super()._encode(inputs)
+        images = inputs.images
+        input_ids = encoded['input_ids']
+        visual_tokenizer = self.model.visual_tokenizer
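+        # Positions of the -200 placeholders inserted by replace_tag.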
+        idx_list = findall(input_ids, [-200])
+        if inputs.videos:
+            assert len(inputs.videos) == 1, 'only support single video'
+            encoded['pixel_values'], encoded['grid_thws'] = visual_tokenizer.preprocess(
+                video=inputs.images, min_pixels=min_pixels, max_pixels=video_max_pixels)
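+            # Visual token count per sample = prod(grid_thw) / hidden_stride^2 / temporal_patch_size.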
+            num_video_tokens = encoded['grid_thws'].prod(dim=-1)
+            num_video_tokens //= visual_tokenizer.vit.config.hidden_stride**2
+            num_video_tokens //= visual_tokenizer.vit.config.temporal_patch_size
+
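+            # Each video placeholder expands to -303 (begin), token_len atom ids (-300) and -304 (end);
+            # these negative ids are later consumed by merge_multimodal in _post_encode.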
+            def _get_new_tokens(i):
+                token_len = num_video_tokens[i].item()
+                return [-303] + [-300] * token_len + [-304]
+
+            input_ids, encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
+                input_ids, encoded['labels'], encoded['loss_scale'], idx_list, _get_new_tokens)
+        elif images:
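+            # Preprocess each image independently, then concatenate patches and grids along dim 0.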
+            pixel_values, grid_thws = zip(
+                *(visual_tokenizer.preprocess(image=image, min_pixels=min_pixels, max_pixels=max_pixels)
+                  for image in images))
+            encoded['pixel_values'] = torch.cat(pixel_values, dim=0)
+            encoded['grid_thws'] = torch.cat(grid_thws, dim=0)
+
+            num_image_atoms = encoded['grid_thws'].prod(dim=-1)
+            num_image_atoms //= visual_tokenizer.vit.config.hidden_stride**2
+            num_image_atoms //= visual_tokenizer.vit.config.temporal_patch_size
+
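+            # Image placeholders are bracketed by -301/-302 instead of the video markers.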
+            def _get_new_tokens(i):
+                token_len = num_image_atoms[i].item()
+                return [-301] + [-300] * token_len + [-302]
+
+            input_ids, encoded['labels'], encoded['loss_scale'] = self._extend_tokens(
+                input_ids, encoded['labels'], encoded['loss_scale'], idx_list, _get_new_tokens)
+
+        encoded['input_ids'] = input_ids
+        return encoded
+
+    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
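+        # merge_multimodal swaps the placeholder ids for visual embeddings, so the model only
+        # receives inputs_embeds.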
+        inputs_embeds = model.merge_multimodal(
+            input_ids=inputs['input_ids'],
+            pixel_values=inputs.pop('pixel_values', None),
+            grid_thws=inputs.pop('grid_thws', None))
+        return {'inputs_embeds': inputs_embeds}
+
+    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
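+        # Concatenate grid_thws from all samples in the batch along dim 0.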
+        res = super()._data_collator_mm_data(batch)
+        grid_thws = self.concat_tensor(batch, 'grid_thws', 0)
+        if grid_thws is not None:
+            res['grid_thws'] = grid_thws
+        return res
+
+
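+# Register under MLLMTemplateType.ovis2_5, reusing the Qwen chat layout with no default system prompt.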
+register_template(QwenTemplateMeta(
+    MLLMTemplateType.ovis2_5,
+    template_cls=Ovis2_5Template,
+    default_system=None,
+))
+
+
 @dataclass
 class MarcoO1TemplateMeta(QwenTemplateMeta):
     default_system: Optional[str] = """