2
2
from dataclasses import dataclass , field
3
3
from typing import List , Optional , Union
4
4
5
+ import transformers
6
+ from packaging import version
7
+
8
+ transformers_ge_4_52 = version .parse (transformers .__version__ ) >= version .parse ('4.52' )
9
+
5
10
6
11
class LLMModelArch :
7
12
qwen = 'qwen'
@@ -33,7 +38,6 @@ class MLLMModelArch:
33
38
34
39
llama3_1_omni = 'llama3_1_omni'
35
40
llama3_2_vision = 'llama3_2_vision'
36
- llama4 = 'llama4'
37
41
38
42
llava_hf = 'llava_hf'
39
43
llava_next_video_hf = 'llava_next_video_hf'
@@ -59,14 +63,12 @@ class MLLMModelArch:
59
63
idefics3 = 'idefics3'
60
64
61
65
got_ocr2 = 'got_ocr2'
62
- got_ocr2_hf = 'got_ocr2_hf'
63
66
64
67
ovis1_6 = 'ovis1_6'
65
68
molmo = 'molmo'
66
69
emu3_chat = 'emu3_chat'
67
70
megrez_omni = 'megrez_omni'
68
71
valley = 'valley'
69
- gemma3_vision = 'gemma3_vision'
70
72
mistral_2503 = 'mistral_2503'
71
73
72
74
@@ -308,13 +310,22 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
308
310
lm_head = 'lm_head' ,
309
311
))
310
312
311
- register_model_arch (
312
- MultiModelKeys (
313
- MLLMModelArch .llava_hf ,
314
- language_model = 'language_model' ,
315
- aligner = 'multi_modal_projector' ,
316
- vision_tower = 'vision_tower' ,
317
- ))
313
+ if transformers_ge_4_52 :
314
+ register_model_arch (
315
+ MultiModelKeys (
316
+ MLLMModelArch .llava_hf ,
317
+ language_model = 'model.language_model' ,
318
+ aligner = 'model.multi_modal_projector' ,
319
+ vision_tower = 'model.vision_tower' ,
320
+ ))
321
+ else :
322
+ register_model_arch (
323
+ MultiModelKeys (
324
+ MLLMModelArch .llava_hf ,
325
+ language_model = 'language_model' ,
326
+ aligner = 'multi_modal_projector' ,
327
+ vision_tower = 'vision_tower' ,
328
+ ))
318
329
319
330
register_model_arch (
320
331
MultiModelKeys (
@@ -324,12 +335,20 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
324
335
vision_tower = 'model.vision_tower' ,
325
336
))
326
337
327
- register_model_arch (
328
- MultiModelKeys (
329
- MLLMModelArch .llava_next_video_hf ,
330
- language_model = 'language_model' ,
331
- aligner = ['multi_modal_projector' ],
332
- vision_tower = 'vision_tower' ))
338
+ if transformers_ge_4_52 :
339
+ register_model_arch (
340
+ MultiModelKeys (
341
+ MLLMModelArch .llava_next_video_hf ,
342
+ language_model = 'model.language_model' ,
343
+ aligner = ['model.multi_modal_projector' ],
344
+ vision_tower = 'model.vision_tower' ))
345
+ else :
346
+ register_model_arch (
347
+ MultiModelKeys (
348
+ MLLMModelArch .llava_next_video_hf ,
349
+ language_model = 'language_model' ,
350
+ aligner = ['multi_modal_projector' ],
351
+ vision_tower = 'vision_tower' ))
333
352
334
353
register_model_arch (
335
354
MultiModelKeys (
@@ -459,13 +478,23 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
459
478
vision_tower = 'audio_tower' ,
460
479
))
461
480
462
- register_model_arch (
463
- MultiModelKeys (
464
- MLLMModelArch .qwen2_vl ,
465
- language_model = 'model' ,
466
- aligner = 'visual.merger' ,
467
- vision_tower = 'visual' ,
468
- ))
481
+ if transformers_ge_4_52 :
482
+ register_model_arch (
483
+ MultiModelKeys (
484
+ MLLMModelArch .qwen2_vl ,
485
+ language_model = 'model.language_model' ,
486
+ aligner = 'model.visual.merger' ,
487
+ vision_tower = 'model.visual' ,
488
+ ))
489
+ else :
490
+ register_model_arch (
491
+ MultiModelKeys (
492
+ MLLMModelArch .qwen2_vl ,
493
+ language_model = 'model' ,
494
+ aligner = 'visual.merger' ,
495
+ vision_tower = 'visual' ,
496
+ ))
497
+
469
498
register_model_arch (
470
499
MultiModelKeys (
471
500
MLLMModelArch .qwen2_5_omni ,
@@ -507,13 +536,22 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
507
536
vision_tower = 'model.vision_tower_high' ,
508
537
))
509
538
510
- register_model_arch (
511
- MultiModelKeys (
512
- MLLMModelArch .llama3_2_vision ,
513
- language_model = 'language_model' ,
514
- aligner = 'multi_modal_projector' ,
515
- vision_tower = 'vision_model' ,
516
- ))
539
+ if transformers_ge_4_52 :
540
+ register_model_arch (
541
+ MultiModelKeys (
542
+ MLLMModelArch .llama3_2_vision ,
543
+ language_model = 'model.language_model' ,
544
+ aligner = 'model.multi_modal_projector' ,
545
+ vision_tower = 'model.vision_model' ,
546
+ ))
547
+ else :
548
+ register_model_arch (
549
+ MultiModelKeys (
550
+ MLLMModelArch .llama3_2_vision ,
551
+ language_model = 'language_model' ,
552
+ aligner = 'multi_modal_projector' ,
553
+ vision_tower = 'vision_model' ,
554
+ ))
517
555
518
556
register_model_arch (MultiModelKeys (
519
557
MLLMModelArch .ovis1_6 ,
@@ -547,14 +585,6 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
547
585
vision_tower = ['model.vision_tower' , 'model.qwen2vl_vision_tower' ],
548
586
))
549
587
550
- register_model_arch (
551
- MultiModelKeys (
552
- MLLMModelArch .gemma3_vision ,
553
- language_model = 'language_model' ,
554
- aligner = 'multi_modal_projector' ,
555
- vision_tower = 'vision_tower' ,
556
- ))
557
-
558
588
559
589
def get_model_arch (arch_name : Optional [str ]) -> Optional [MultiModelKeys ]:
560
590
return MODEL_ARCH_MAPPING .get (arch_name )
0 commit comments