Commit f269f56

[FEATURE] Add Llava series to transformers (#1016)
* add llava-next test
* add llava next test
* add UT for llava models
* add example
* add mistral
* add eager args
* add support for mistral flash attention
* add generate
* fix cache
* add generate
* add generate for llava one vision
* raise error for sdpa attention currently
* align llava-next to v4.50.0
* align llava-next-video to 4.50.0
* align llava-onevision to v4.50.0
* clean
* fix merge
* clean generate example & add processor
1 parent 458c91d commit f269f56

26 files changed: 5,741 additions, 1 deletion

mindone/transformers/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -278,6 +278,25 @@
 )
 from .models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
 from .models.llava import LlavaConfig, LlavaForConditionalGeneration
+from .models.llava_next import (
+    LlavaNextForConditionalGeneration,
+    LlavaNextImageProcessor,
+    LlavaNextPreTrainedModel,
+    LlavaNextProcessor,
+)
+from .models.llava_next_video import (
+    LlavaNextVideoForConditionalGeneration,
+    LlavaNextVideoImageProcessor,
+    LlavaNextVideoPreTrainedModel,
+    LlavaNextVideoProcessor,
+)
+from .models.llava_onevision import (
+    LlavaOnevisionForConditionalGeneration,
+    LlavaOnevisionImageProcessor,
+    LlavaOnevisionPreTrainedModel,
+    LlavaOnevisionProcessor,
+    LlavaOnevisionVideoProcessor,
+)
 from .models.m2m_100 import M2M100ForConditionalGeneration, M2M100Model, M2M100PreTrainedModel
 from .models.megatron_bert import (
     MegatronBertForCausalLM,
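
With these exports in place, the new classes are importable straight from the top-level package. A minimal sketch of the public surface added here (assumes `mindone` is installed):

```python
# Import the newly exported Llava-series classes from the top-level package.
from mindone.transformers import (
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    LlavaNextVideoForConditionalGeneration,
    LlavaOnevisionForConditionalGeneration,
)
```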

mindone/transformers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -946,7 +946,7 @@ def _from_config(cls, config, **kwargs):

         if isinstance(mindspore_dtype, str):
             mindspore_dtype = getattr(ms, mindspore_dtype)
-        else:
+        elif mindspore_dtype is not None:
             TORCH_TO_MINDSPORE_DTYPE_MAP = {
                 "torch.float32": ms.float32,
                 "torch.bfloat16": ms.bfloat16,

mindone/transformers/models/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -51,6 +51,9 @@
     levit,
     llama,
     llava,
+    llava_next,
+    llava_next_video,
+    llava_onevision,
     m2m_100,
     megatron_bert,
     minicpm4,
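
Registering the subpackages in the models namespace makes them importable as modules:

```python
from mindone.transformers.models import llava_next, llava_next_video, llava_onevision
```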

mindone/transformers/models/auto/configuration_auto.py

Lines changed: 7 additions & 0 deletions
@@ -76,6 +76,9 @@
         ("persimmon", "PersimmonConfig"),
         ("fuyu", "FuyuConfig"),
         ("llava", "LlavaConfig"),
+        ("llava_next", "LlavaNextConfig"),
+        ("llava_next_video", "LlavaNextVideoConfig"),
+        ("llava_onevision", "LlavaOnevisionConfig"),
         ("mistral", "MistralConfig"),
         ("mobilebert", "MobileBertConfig"),
         ("mpt", "MptConfig"),
@@ -162,6 +165,10 @@
         ("llama2", "Llama2"),
         ("llama3", "Llama3"),
         ("llava", "Llava"),
+        ("llava_next", "LLaVA-NeXT"),
+        ("llava_next_video", "LLaVa-NeXT-Video"),
+        ("llava_onevision", "LLaVA-Onevision"),
+        ("mistral", "Mistral"),
         ("persimmon", "Persimmon"),
         ("fuyu", "Fuyu"),
         ("mobilebert", "MobileBERT"),

mindone/transformers/models/auto/image_processing_auto.py

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@
         ("blip-2", ("BlipImageProcessor",)),
         ("clip", ("CLIPImageProcessor",)),
         ("dpt", ("DPTImageProcessor",)),
+        ("llava_next", ("LlavaNextImageProcessor",)),
+        ("llava_next_video", ("LlavaNextVideoImageProcessor",)),
+        ("llava_onevision", ("LlavaOnevisionImageProcessor",)),
     ]
 )
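The image-processor mapping gives `AutoImageProcessor` the same resolution path; a sketch under the same assumptions:

```python
from mindone.transformers import AutoImageProcessor  # assumed export, mirroring upstream transformers

image_processor = AutoImageProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
print(type(image_processor).__name__)  # expected: LlavaNextImageProcessor
```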

mindone/transformers/models/auto/modeling_auto.py

Lines changed: 11 additions & 0 deletions
@@ -123,6 +123,9 @@
         ("idefics2", "Idefics2ForConditionalGeneration"),
         ("idefics3", "Idefics3ForConditionalGeneration"),
         ("llava", "LlavaForConditionalGeneration"),
+        ("llava_next", "LlavaNextForConditionalGeneration"),
+        ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
+        ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
         ("mobilebert", "MobileBertForPreTraining"),
         ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
         ("roberta", "RobertaForMaskedLM"),
@@ -214,6 +217,7 @@
         ("ijepa", "IJepaModel"),
         ("imagegpt", "ImageGPTModel"),
         ("levit", "LevitModel"),
+        ("siglip_vision_model", "SiglipVisionModel"),
     ]
 )

@@ -260,6 +264,9 @@
         ("idefics2", "Idefics2ForConditionalGeneration"),
         ("idefics3", "Idefics3ForConditionalGeneration"),
         ("llava", "LlavaForConditionalGeneration"),
+        ("llava_next", "LlavaNextForConditionalGeneration"),
+        ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
+        ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
         ("paligemma", "PaliGemmaForConditionalGeneration"),
         ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
         ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
@@ -278,6 +285,8 @@
         ("idefics3", "Idefics3ForConditionalGeneration"),
         ("fuyu", "FuyuForCausalLM"),
         ("llava", "LlavaForConditionalGeneration"),
+        ("llava_next", "LlavaNextForConditionalGeneration"),
+        ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
         ("paligemma", "PaliGemmaForConditionalGeneration"),
         ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
         ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
@@ -410,6 +419,7 @@
         ("led", "LEDForQuestionAnswering"),
         ("convbert", "ConvBertForQuestionAnswering"),
         ("llama", "LlamaForQuestionAnswering"),
+        ("mistral", "MistralForQuestionAnswering"),
         ("mobilebert", "MobileBertForQuestionAnswering"),
         ("megatron-bert", "MegatronBertForQuestionAnswering"),
         ("mistral", "MistralForQuestionAnswering"),
@@ -529,6 +539,7 @@

 MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
+        # Model for Zero Shot Image Classification mapping
         ("blip", "BlipModel"),
         ("siglip", "SiglipModel"),
     ]
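
Since the three models are registered for image-text-to-text generation (and llava_next/llava_onevision additionally for vision-to-seq), the familiar `from_pretrained` + `generate` flow applies; the commit message notes that `generate` support was added for the whole series. A minimal sketch; the checkpoint name, image URL, prompt template, and the numpy-to-Tensor conversion step are all assumptions:

```python
import mindspore as ms
import requests
from PIL import Image

from mindone.transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"  # illustrative checkpoint
processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, mindspore_dtype=ms.float16)

image = Image.open(requests.get("https://example.com/cat.png", stream=True).raw)  # illustrative URL
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"  # Mistral-style template, assumed

inputs = processor(images=image, text=prompt, return_tensors="np")
# Assumption: mindone models take ms.Tensor inputs, so convert the numpy arrays.
inputs = {k: ms.Tensor(v) for k, v in inputs.items()}

output_ids = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```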

mindone/transformers/models/auto/processing_auto.py

Lines changed: 3 additions & 0 deletions
@@ -50,6 +50,9 @@
 PROCESSOR_MAPPING_NAMES = OrderedDict(
     [
         ("blip", "BlipProcessor"),
+        ("llava_next", "LlavaNextProcessor"),
+        ("llava_next_video", "LlavaNextVideoProcessor"),
+        ("llava_onevision", "LlavaOnevisionProcessor"),
     ]
 )
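`AutoProcessor` resolves through the new entries the same way; a sketch under the same assumptions as above:

```python
from mindone.transformers import AutoProcessor  # assumed export, mirroring upstream transformers

processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
print(type(processor).__name__)  # expected: LlavaNextProcessor
```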

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .image_processing_llava_next import *
2+
from .modeling_llava_next import *
3+
from .processing_llava_next import *
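
The wildcard re-exports let callers import from the subpackage without knowing its module layout; these two lines are equivalent:

```python
from mindone.transformers.models.llava_next import LlavaNextProcessor
from mindone.transformers.models.llava_next.processing_llava_next import LlavaNextProcessor
```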
