22from dataclasses import dataclass , field
33from typing import List , Optional , Union
44
5+ import transformers
6+ from packaging import version
7+
8+ transformers_ge_4_52 = version .parse (transformers .__version__ ) >= version .parse ('4.52' )
9+
510
611class LLMModelArch :
712 qwen = 'qwen'
@@ -33,7 +38,6 @@ class MLLMModelArch:
3338
3439 llama3_1_omni = 'llama3_1_omni'
3540 llama3_2_vision = 'llama3_2_vision'
36- llama4 = 'llama4'
3741
3842 llava_hf = 'llava_hf'
3943 llava_next_video_hf = 'llava_next_video_hf'
@@ -59,14 +63,12 @@ class MLLMModelArch:
5963 idefics3 = 'idefics3'
6064
6165 got_ocr2 = 'got_ocr2'
62- got_ocr2_hf = 'got_ocr2_hf'
6366
6467 ovis1_6 = 'ovis1_6'
6568 molmo = 'molmo'
6669 emu3_chat = 'emu3_chat'
6770 megrez_omni = 'megrez_omni'
6871 valley = 'valley'
69- gemma3_vision = 'gemma3_vision'
7072 mistral_2503 = 'mistral_2503'
7173
7274
@@ -308,13 +310,22 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
308310 lm_head = 'lm_head' ,
309311 ))
310312
311- register_model_arch (
312- MultiModelKeys (
313- MLLMModelArch .llava_hf ,
314- language_model = 'language_model' ,
315- aligner = 'multi_modal_projector' ,
316- vision_tower = 'vision_tower' ,
317- ))
313+ if transformers_ge_4_52 :
314+ register_model_arch (
315+ MultiModelKeys (
316+ MLLMModelArch .llava_hf ,
317+ language_model = 'model.language_model' ,
318+ aligner = 'model.multi_modal_projector' ,
319+ vision_tower = 'model.vision_tower' ,
320+ ))
321+ else :
322+ register_model_arch (
323+ MultiModelKeys (
324+ MLLMModelArch .llava_hf ,
325+ language_model = 'language_model' ,
326+ aligner = 'multi_modal_projector' ,
327+ vision_tower = 'vision_tower' ,
328+ ))
318329
319330register_model_arch (
320331 MultiModelKeys (
@@ -324,12 +335,20 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
324335 vision_tower = 'model.vision_tower' ,
325336 ))
326337
327- register_model_arch (
328- MultiModelKeys (
329- MLLMModelArch .llava_next_video_hf ,
330- language_model = 'language_model' ,
331- aligner = ['multi_modal_projector' ],
332- vision_tower = 'vision_tower' ))
338+ if transformers_ge_4_52 :
339+ register_model_arch (
340+ MultiModelKeys (
341+ MLLMModelArch .llava_next_video_hf ,
342+ language_model = 'model.language_model' ,
343+ aligner = ['model.multi_modal_projector' ],
344+ vision_tower = 'model.vision_tower' ))
345+ else :
346+ register_model_arch (
347+ MultiModelKeys (
348+ MLLMModelArch .llava_next_video_hf ,
349+ language_model = 'language_model' ,
350+ aligner = ['multi_modal_projector' ],
351+ vision_tower = 'vision_tower' ))
333352
334353register_model_arch (
335354 MultiModelKeys (
@@ -459,13 +478,23 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
459478 vision_tower = 'audio_tower' ,
460479 ))
461480
462- register_model_arch (
463- MultiModelKeys (
464- MLLMModelArch .qwen2_vl ,
465- language_model = 'model' ,
466- aligner = 'visual.merger' ,
467- vision_tower = 'visual' ,
468- ))
481+ if transformers_ge_4_52 :
482+ register_model_arch (
483+ MultiModelKeys (
484+ MLLMModelArch .qwen2_vl ,
485+ language_model = 'model.language_model' ,
486+ aligner = 'model.visual.merger' ,
487+ vision_tower = 'model.visual' ,
488+ ))
489+ else :
490+ register_model_arch (
491+ MultiModelKeys (
492+ MLLMModelArch .qwen2_vl ,
493+ language_model = 'model' ,
494+ aligner = 'visual.merger' ,
495+ vision_tower = 'visual' ,
496+ ))
497+
469498register_model_arch (
470499 MultiModelKeys (
471500 MLLMModelArch .qwen2_5_omni ,
@@ -507,13 +536,22 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
507536 vision_tower = 'model.vision_tower_high' ,
508537 ))
509538
510- register_model_arch (
511- MultiModelKeys (
512- MLLMModelArch .llama3_2_vision ,
513- language_model = 'language_model' ,
514- aligner = 'multi_modal_projector' ,
515- vision_tower = 'vision_model' ,
516- ))
539+ if transformers_ge_4_52 :
540+ register_model_arch (
541+ MultiModelKeys (
542+ MLLMModelArch .llama3_2_vision ,
543+ language_model = 'model.language_model' ,
544+ aligner = 'model.multi_modal_projector' ,
545+ vision_tower = 'model.vision_model' ,
546+ ))
547+ else :
548+ register_model_arch (
549+ MultiModelKeys (
550+ MLLMModelArch .llama3_2_vision ,
551+ language_model = 'language_model' ,
552+ aligner = 'multi_modal_projector' ,
553+ vision_tower = 'vision_model' ,
554+ ))
517555
518556register_model_arch (MultiModelKeys (
519557 MLLMModelArch .ovis1_6 ,
@@ -547,14 +585,6 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
547585 vision_tower = ['model.vision_tower' , 'model.qwen2vl_vision_tower' ],
548586 ))
549587
550- register_model_arch (
551- MultiModelKeys (
552- MLLMModelArch .gemma3_vision ,
553- language_model = 'language_model' ,
554- aligner = 'multi_modal_projector' ,
555- vision_tower = 'vision_tower' ,
556- ))
557-
558588
559589def get_model_arch (arch_name : Optional [str ]) -> Optional [MultiModelKeys ]:
560590 return MODEL_ARCH_MAPPING .get (arch_name )
0 commit comments