[model] support LLaVA-OneVision-1.5 #6284
base: main
Changes from 16 commits
The first changed file registers the model loader and model metadata:

```diff
@@ -5,6 +5,7 @@
from typing import Any, Dict

from transformers import AutoConfig
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from swift.llm import TemplateType
from ..constant import MLLMModelType
@@ -389,3 +390,32 @@ def _new_forward(*args, **kwargs):
        requires=['transformers>=4.42', 'av'],
        tags=['vision'],
        model_arch=None))


def get_model_tokenizer_llava_onevision1_5(model_dir, *args, **kwargs):
    model_cls = get_class_from_dynamic_module('modeling_llavaonevision1_5.LLaVAOneVision1_5_ForConditionalGeneration',
                                              model_dir)
    model_cls._no_split_modules = ['LLaVAOneVision1_5_DecoderLayer', 'RiceBlock']
    model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
    model.config.vision_start_token_id = 151652
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.llava_onevision1_5,
        [
            ModelGroup([
                Model('lmms-lab/LLaVA-OneVision-1.5-4B-Instruct', 'lmms-lab/LLaVA-OneVision-1.5-4B-Instruct'),
                Model('lmms-lab/LLaVA-OneVision-1.5-8B-Instruct', 'lmms-lab/LLaVA-OneVision-1.5-8B-Instruct'),
                Model('lmms-lab/LLaVA-OneVision-1.5-4B-Base', 'lmms-lab/LLaVA-OneVision-1.5-4B-Base'),
                Model('lmms-lab/LLaVA-OneVision-1.5-8B-Base', 'lmms-lab/LLaVA-OneVision-1.5-8B-Base'),
            ], ),
        ],
        TemplateType.llava_onevision1_5,
        get_model_tokenizer_llava_onevision1_5,
        architectures=['LLaVAOneVision1_5_ForConditionalGeneration'],
        model_arch=ModelArch.llava_onevision1_5,
        requires=['transformers>=4.53.0', 'qwen_vl_utils'],
        tags=['vision'],
    ))
```
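For context, here is a minimal usage sketch of how the newly registered model type could be exercised from Python. It assumes `swift.llm.get_model_tokenizer` accepts a model id as its first argument, as it does for other registered multimodal models; the snippet is illustrative and not part of this PR.

```python
# Minimal usage sketch (assumption: swift.llm.get_model_tokenizer resolves the
# registered ModelMeta from the model id; illustrative, not part of the PR).
from swift.llm import get_model_tokenizer

model, processor = get_model_tokenizer('lmms-lab/LLaVA-OneVision-1.5-4B-Instruct')

# The registration above routes loading through get_model_tokenizer_llava_onevision1_5,
# which fetches the remote-code model class and sets config.vision_start_token_id.
print(model.config.vision_start_token_id)  # 151652
```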
The second changed file adds the corresponding chat template:

```diff
@@ -6,6 +6,7 @@
import transformers
from packaging import version

from swift.utils import get_env_args
from ..base import Template
from ..constant import MLLMTemplateType
from ..register import TemplateMeta, register_template
@@ -307,3 +308,101 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
    ))

register_template(QwenTemplateMeta(MLLMTemplateType.llava_next_qwen, template_cls=LLavaTemplate))


class LLavaOneVision1_5Template(Template):
    image_token_id = 151655
    video_token_id = 151656
    placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
    use_model = True
    support_padding_free = True

    def init_env_args(self):
        super().init_env_args()
        self.bbox_format = get_env_args('QWENVL_BBOX_FORMAT', str, 'legacy')

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        from qwen_vl_utils import fetch_image, fetch_video
        assert media_type in {'image', 'video'}
        if media_type == 'image':
            inputs.images[index] = fetch_image({'image': inputs.images[index]})
            if self.mode == 'lmdeploy':
                return ['<|vision_start|>', [-100], '<|vision_end|>']
            else:
                return ['<|vision_start|><|image_pad|><|vision_end|>']
        else:
            video = inputs.videos[index]
            video, video_kwargs = fetch_video({'video': video}, return_video_sample_fps=True)
            inputs.mm_processor_kwargs.setdefault('fps', []).append(video_kwargs)
            tokens = ['<|vision_start|><|video_pad|><|vision_end|>']
            if isinstance(video, torch.Tensor):
                video = video.to(torch.uint8)
            inputs.videos[index] = video
            return tokens

    def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
        if self.bbox_format == 'legacy':
            return [f'<|object_ref_start|>{ref}<|object_ref_end|>']
        else:
            return [ref]

    def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
        if self.bbox_format == 'legacy':
            return [f'<|box_start|>{self._get_bbox_str(bbox)}<|box_end|>']
        else:
            return [str(bbox)]

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        processor = self.processor
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        loss_scale = encoded.get('loss_scale', None)
        for media_type in ['images', 'videos']:
            mm_data = getattr(inputs, media_type)
            if mm_data:
                if media_type == 'images':
                    media_token = self.image_token_id
                    media_inputs = processor.image_processor(images=mm_data, return_tensors='pt', do_resize=False)
                    media_grid_thw = media_inputs['image_grid_thw']
                else:
                    kwargs = {}
                    if hasattr(processor, 'video_processor'):
                        processor_func = processor.video_processor
                    else:
                        processor_func = processor.image_processor
                        kwargs['images'] = None
```
> **Review comment on lines +371 to +375:** The fallback to …
```diff
                    media_inputs = processor_func(videos=mm_data, return_tensors='pt', do_resize=False, **kwargs)
                    media_grid_thw = media_inputs['video_grid_thw']
                    media_token = self.video_token_id
                idx_list = findall(input_ids, media_token)
                merge_length = processor.image_processor.merge_size**2

                def _get_new_tokens(i):
                    token_len = (media_grid_thw[i].prod() // merge_length)
                    return [media_token] * token_len

                input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
                                                                    _get_new_tokens)
                encoded.update(media_inputs)

        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        encoded['loss_scale'] = loss_scale
        return encoded

    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if not self.is_training:
            return inputs
        input_ids = inputs['input_ids']
        base_model = self.get_base_model(model)
        if hasattr(base_model.model, 'embed_tokens'):
            inputs_embeds = base_model.model.embed_tokens(input_ids)
        else:
            inputs_embeds = base_model.model.language_model.embed_tokens(input_ids)
        inputs_embeds = self._get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config)
        return {'inputs_embeds': inputs_embeds}


register_template(QwenTemplateMeta(MLLMTemplateType.llava_onevision1_5, template_cls=LLavaOneVision1_5Template))
```
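To make the placeholder expansion in `_encode` concrete, here is a small worked example of the arithmetic behind `_get_new_tokens`; the grid values and merge size are illustrative, not taken from this PR.

```python
import torch

# Hypothetical grid for a single image: 1 temporal slice, 34 x 46 vision patches
# (illustrative values; in practice the processor produces the real image_grid_thw).
image_grid_thw = torch.tensor([[1, 34, 46]])
merge_size = 2                   # Qwen2-VL-style processors merge 2x2 patches per token
merge_length = merge_size ** 2   # 4

# Same arithmetic as _get_new_tokens: total patches divided by the merge factor.
token_len = int(image_grid_thw[0].prod()) // merge_length  # 1 * 34 * 46 // 4 = 391

# The single <|image_pad|> placeholder (id 151655) in input_ids would therefore be
# replaced by 391 copies of image_token_id before labels and loss_scale are extended.
print(token_len)  # 391
```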
> **Review comment** (on `model.config.vision_start_token_id = 151652` in the model registration above): The value `151652` is a magic number. To improve code readability and maintainability, it's better to define it as a constant with a descriptive name (e.g., `LLAVA_ONEVISION_VISION_START_TOKEN_ID`) at the module level and use the constant here.
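Applying that suggestion would look roughly like the sketch below; the constant name is the reviewer's example, and the function body otherwise mirrors the one added in this PR.

```python
# Sketch of the reviewer's suggestion (constant name taken from the review comment;
# not code that exists in the PR as submitted). This lives in the same module as the
# original function, so get_class_from_dynamic_module and get_model_tokenizer_multimodal
# are already imported there.
LLAVA_ONEVISION_VISION_START_TOKEN_ID = 151652  # id of <|vision_start|> in the Qwen2-VL tokenizer


def get_model_tokenizer_llava_onevision1_5(model_dir, *args, **kwargs):
    model_cls = get_class_from_dynamic_module(
        'modeling_llavaonevision1_5.LLaVAOneVision1_5_ForConditionalGeneration', model_dir)
    model_cls._no_split_modules = ['LLaVAOneVision1_5_DecoderLayer', 'RiceBlock']
    model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
    model.config.vision_start_token_id = LLAVA_ONEVISION_VISION_START_TOKEN_ID
    return model, processor
```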