- 
                Notifications
    You must be signed in to change notification settings 
- Fork 923
[model] support LLaVA-OneVision-1.5 #6284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c4a0133
              662c6ff
              0236c2d
              0dd6511
              dfbed96
              7be980c
              3b60a9d
              c3a2628
              05dbcce
              f72659e
              10edf1e
              1367fe1
              5943294
              2f72847
              1e99d19
              988cb7c
              c5a2a9f
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|  | @@ -5,6 +5,7 @@ | |||||||||
| from typing import Any, Dict | ||||||||||
|  | ||||||||||
| from transformers import AutoConfig | ||||||||||
| from transformers.dynamic_module_utils import get_class_from_dynamic_module | ||||||||||
|  | ||||||||||
| from swift.llm import TemplateType | ||||||||||
| from ..constant import MLLMModelType | ||||||||||
|  | @@ -389,3 +390,32 @@ def _new_forward(*args, **kwargs): | |||||||||
| requires=['transformers>=4.42', 'av'], | ||||||||||
| tags=['vision'], | ||||||||||
| model_arch=None)) | ||||||||||
|  | ||||||||||
|  | ||||||||||
| def get_model_tokenizer_llava_onevision1_5(model_dir, *args, **kwargs): | ||||||||||
| model_cls = get_class_from_dynamic_module('modeling_llavaonevision1_5.LLaVAOneVision1_5_ForConditionalGeneration', | ||||||||||
| model_dir) | ||||||||||
| model_cls._no_split_modules = ['LLaVAOneVision1_5_DecoderLayer', 'RiceBlock'] | ||||||||||
| model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs) | ||||||||||
| model.config.vision_start_token_id = 151652 | ||||||||||
| return model, processor | ||||||||||
|  | ||||||||||
|  | ||||||||||
| register_model( | ||||||||||
| ModelMeta( | ||||||||||
| MLLMModelType.llava_onevision1_5, | ||||||||||
| [ | ||||||||||
| ModelGroup([ | ||||||||||
| Model('lmms-lab/LLaVA-OneVision-1.5-4B-Instruct', 'lmms-lab/LLaVA-OneVision-1.5-4B-Instruct'), | ||||||||||
| Model('lmms-lab/LLaVA-OneVision-1.5-8B-Instruct', 'lmms-lab/LLaVA-OneVision-1.5-8B-Instruct'), | ||||||||||
| Model('lmms-lab/LLaVA-OneVision-1.5-4B-Base', 'lmms-lab/LLaVA-OneVision-1.5-4B-Base'), | ||||||||||
| Model('lmms-lab/LLaVA-OneVision-1.5-8B-Base', 'lmms-lab/LLaVA-OneVision-1.5-8B-Base'), | ||||||||||
| ], ), | ||||||||||
| ], | ||||||||||
| TemplateType.llava_onevision1_5, | ||||||||||
| get_model_tokenizer_llava_onevision1_5, | ||||||||||
| architectures=['LLaVAOneVision1_5_ForConditionalGeneration'], | ||||||||||
| model_arch=ModelArch.llava_onevision1_5, | ||||||||||
| requires=['transformers>=4.53.0', 'qwen_vl_utils'], | ||||||||||
| tags=['vision'], | ||||||||||
| 
      Comment on lines
    
      +419
     to 
      +420
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The  
        Suggested change
       
 | ||||||||||
| )) | ||||||||||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -6,6 +6,7 @@ | |
| import transformers | ||
| from packaging import version | ||
|  | ||
| from swift.utils import get_env_args | ||
| from ..base import Template | ||
| from ..constant import MLLMTemplateType | ||
| from ..register import TemplateMeta, register_template | ||
|  | @@ -307,3 +308,101 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in | |
| )) | ||
|  | ||
| register_template(QwenTemplateMeta(MLLMTemplateType.llava_next_qwen, template_cls=LLavaTemplate)) | ||
|  | ||
|  | ||
| class LLavaOneVision1_5Template(Template): | ||
| image_token_id = 151655 | ||
| video_token_id = 151656 | ||
| 
      Comment on lines
    
      +314
     to 
      +315
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. | ||
| placeholder_tokens = ['<|image_pad|>', '<|video_pad|>'] | ||
| use_model = True | ||
| support_padding_free = True | ||
|  | ||
| def init_env_args(self): | ||
| super().init_env_args() | ||
| self.bbox_format = get_env_args('QWENVL_BBOX_FORMAT', str, 'legacy') | ||
|  | ||
| def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, | ||
| inputs: StdTemplateInputs) -> List[Context]: | ||
| from qwen_vl_utils import fetch_image, fetch_video | ||
| assert media_type in {'image', 'video'} | ||
| if media_type == 'image': | ||
| inputs.images[index] = fetch_image({'image': inputs.images[index]}) | ||
| if self.mode == 'lmdeploy': | ||
| return ['<|vision_start|>', [-100], '<|vision_end|>'] | ||
| else: | ||
| return ['<|vision_start|><|image_pad|><|vision_end|>'] | ||
| else: | ||
| video = inputs.videos[index] | ||
| video, video_kwargs = fetch_video({'video': video}, return_video_sample_fps=True) | ||
| inputs.mm_processor_kwargs.setdefault('fps', []).append(video_kwargs) | ||
| tokens = ['<|vision_start|><|video_pad|><|vision_end|>'] | ||
| if isinstance(video, torch.Tensor): | ||
| video = video.to(torch.uint8) | ||
| inputs.videos[index] = video | ||
| return tokens | ||
|  | ||
| def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]: | ||
| if self.bbox_format == 'legacy': | ||
| return [f'<|object_ref_start|>{ref}<|object_ref_end|>'] | ||
| else: | ||
| return [ref] | ||
|  | ||
| def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]: | ||
| if self.bbox_format == 'legacy': | ||
| return [f'<|box_start|>{self._get_bbox_str(bbox)}<|box_end|>'] | ||
| else: | ||
| return [str(bbox)] | ||
|  | ||
| def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: | ||
| encoded = super()._encode(inputs) | ||
| processor = self.processor | ||
| input_ids = encoded['input_ids'] | ||
| labels = encoded['labels'] | ||
| loss_scale = encoded.get('loss_scale', None) | ||
| for media_type in ['images', 'videos']: | ||
| mm_data = getattr(inputs, media_type) | ||
| if mm_data: | ||
| if media_type == 'images': | ||
| media_token = self.image_token_id | ||
| media_inputs = processor.image_processor(images=mm_data, return_tensors='pt', do_resize=False) | ||
| media_grid_thw = media_inputs['image_grid_thw'] | ||
| else: | ||
| kwargs = {} | ||
| if hasattr(processor, 'video_processor'): | ||
| processor_func = processor.video_processor | ||
| else: | ||
| processor_func = processor.image_processor | ||
| kwargs['images'] = None | ||
| 
      Comment on lines
    
      +371
     to 
      +375
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fallback to  | ||
| media_inputs = processor_func(videos=mm_data, return_tensors='pt', do_resize=False, **kwargs) | ||
| media_grid_thw = media_inputs['video_grid_thw'] | ||
| media_token = self.video_token_id | ||
| idx_list = findall(input_ids, media_token) | ||
| merge_length = processor.image_processor.merge_size**2 | ||
|  | ||
| def _get_new_tokens(i): | ||
| token_len = (media_grid_thw[i].prod() // merge_length) | ||
| return [media_token] * token_len | ||
|  | ||
| input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list, | ||
| _get_new_tokens) | ||
| encoded.update(media_inputs) | ||
|  | ||
| encoded['input_ids'] = input_ids | ||
| encoded['labels'] = labels | ||
| encoded['loss_scale'] = loss_scale | ||
| return encoded | ||
|  | ||
| def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||
| if not self.is_training: | ||
| return inputs | ||
| input_ids = inputs['input_ids'] | ||
| base_model = self.get_base_model(model) | ||
| if hasattr(base_model.model, 'embed_tokens'): | ||
| inputs_embeds = base_model.model.embed_tokens(input_ids) | ||
| else: | ||
| inputs_embeds = base_model.model.language_model.embed_tokens(input_ids) | ||
| inputs_embeds = self._get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config) | ||
| return {'inputs_embeds': inputs_embeds} | ||
|  | ||
|  | ||
| register_template(QwenTemplateMeta(MLLMTemplateType.llava_onevision1_5, template_cls=LLavaOneVision1_5Template)) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The value
151652is a magic number. To improve code readability and maintainability, it's better to define it as a constant with a descriptive name (e.g.,LLAVA_ONEVISION_VISION_START_TOKEN_ID) at the module level and use the constant here.