# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal

import torch

from ..base import Template
from ..constant import MLLMTemplateType
from ..register import register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Word, findall
from .utils import ChatmlTemplateMeta


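# ChatML-style template metadata for Keye-VL: no BOS token is auto-prepended
# and generation stops at <|endoftext|>.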
@dataclass
class KeyeTemplateMeta(ChatmlTemplateMeta):
    auto_add_bos: bool = False
    stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>'])


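# Keye-VL reuses Qwen2-VL-style vision special tokens: ids 151655/151656 are
# the <|image_pad|>/<|video_pad|> placeholders that _encode expands per sample.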
class KeyeVLTemplate(Template):
    image_token_id = 151655
    video_token_id = 151656
    placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']

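    # Replace a generic image/video tag with the vision placeholder span.
    # Media are loaded and preprocessed eagerly via keye_vl_utils, so the
    # processor can run with do_resize=False later in _encode.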
    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        # keye_vl_utils mirrors the qwen_vl_utils fetch_image/fetch_video interface.
        from keye_vl_utils import fetch_image, fetch_video
        assert media_type in {'image', 'video'}, f'unsupported media_type: {media_type}'
        if media_type == 'image':
            inputs.images[index] = fetch_image({'image': inputs.images[index]})
            if getattr(self, 'mode', None) == 'lmdeploy':
                return ['<|vision_start|>', [-100], '<|vision_end|>']
            else:
                return ['<|vision_start|><|image_pad|><|vision_end|>']
        else:
            video = inputs.videos[index]
            if os.path.isdir(video):
                # A directory of frames: sort so the frame order is deterministic.
                video = [os.path.join(video, fname) for fname in sorted(os.listdir(video))]
            video = fetch_video({'video': video})
            if isinstance(video, torch.Tensor):
                video = video.to(torch.uint8)
            inputs.videos[index] = video
            return ['<|vision_start|><|video_pad|><|vision_end|>']

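    # Tokenize the chat, then expand every single image/video pad token into the
    # per-sample count implied by the vision grid: t * h * w // merge_size**2.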
    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        from keye_vl_utils import vision_process
        encoded = super()._encode(inputs)
        processor = self.processor
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        images = inputs.images
        videos = inputs.videos
        for media_type, media_list in (('images', images), ('videos', videos)):
            if media_list:
                if media_type == 'images':
                    media_token = self.image_token_id
                    media_inputs = processor.image_processor(
                        images=images, videos=None, return_tensors='pt', do_resize=False)
                    media_grid_thw = media_inputs['image_grid_thw']
                else:
                    if hasattr(processor, 'video_processor'):
                        processor_func = processor.video_processor
                    else:
                        processor_func = processor.image_processor
                    media_inputs = processor_func(images=None, videos=videos, return_tensors='pt', do_resize=False)
                    media_grid_thw = media_inputs['video_grid_thw']
                    media_token = self.video_token_id
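                    # Seconds of video represented by one temporal grid step:
                    # temporal_patch_size frames at keye_vl_utils' sampling FPS.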
                    media_inputs['second_per_grid_ts'] = [
                        processor.image_processor.temporal_patch_size / vision_process.FPS
                    ] * len(media_grid_thw)
                idx_list = findall(input_ids, media_token)
                merge_length = processor.image_processor.merge_size**2

                def _get_new_tokens(i):
                    token_len = int(media_grid_thw[i].prod()) // merge_length
                    return [media_token] * token_len

                input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
                encoded.update(media_inputs)

        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        return encoded

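    # Collate multimodal fields across the batch: grid_thw tensors are
    # concatenated along dim 0, second_per_grid_ts is flattened into one list.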
    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        res = super()._data_collator_mm_data(batch)
        second_per_grid_ts = self.gather_list(batch, 'second_per_grid_ts')
        if second_per_grid_ts:
            res['second_per_grid_ts'] = second_per_grid_ts
        for media_type in ['image', 'video']:
            grid_thw = self.concat_tensor(batch, f'{media_type}_grid_thw', 0)
            if grid_thw is not None:
                res[f'{media_type}_grid_thw'] = grid_thw
        return res


# Register the Keye VL template
register_template(KeyeTemplateMeta(MLLMTemplateType.keye_vl, template_cls=KeyeVLTemplate))
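
# A minimal usage sketch (not part of the template itself). It assumes swift's
# public get_model_tokenizer/get_template helpers and an illustrative checkpoint
# id; both may need adjusting to the installed ms-swift version.
if __name__ == '__main__':
    from swift.llm import get_model_tokenizer, get_template

    # Model id is an assumption, for illustration only.
    model, processor = get_model_tokenizer('Kwai-Keye/Keye-VL-8B-Preview')
    template = get_template(MLLMTemplateType.keye_vl, processor)
    encoded = template.encode({
        'messages': [{'role': 'user', 'content': '<image>Describe this picture.'}],
        'images': ['example.jpg'],
    })
    # Expect input_ids/labels plus pixel values and image_grid_thw metadata.
    print(list(encoded.keys()))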