
Commit 8077330

compat transformers==4.52 (vlm) (#4738)
1 parent 59ecbc8 commit 8077330

File tree

4 files changed: +73 -43 lines changed

swift/llm/model/model/gemma.py

Lines changed: 1 addition & 1 deletion
@@ -160,6 +160,6 @@ def get_model_tokenizer_gemma3_vision(model_dir: str,
         TemplateType.gemma3_vision,
         get_model_tokenizer_gemma3_vision,
         architectures=['Gemma3ForConditionalGeneration'],
-        model_arch=ModelArch.gemma3_vision,
+        model_arch=ModelArch.llava_hf,
         requires=['transformers>=4.49'],
     ))
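
Note: the dedicated gemma3_vision arch is retired here. As the model_arch.py diff below shows, the entry it registered carried exactly the same submodule keys as the pre-4.52 llava_hf entry, so Gemma3 can simply reuse ModelArch.llava_hf and pick up the transformers>=4.52 key layout for free. A quick sanity check (a sketch, not repo code; key values copied from the diffs below):

# The removed gemma3_vision registration and the old llava_hf registration
# list identical submodule paths, so the swap is behavior-preserving on
# transformers<4.52.
gemma3_vision_keys = dict(language_model='language_model',
                          aligner='multi_modal_projector',
                          vision_tower='vision_tower')
llava_hf_keys_pre_4_52 = dict(language_model='language_model',
                              aligner='multi_modal_projector',
                              vision_tower='vision_tower')
assert gemma3_vision_keys == llava_hf_keys_pre_4_52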

swift/llm/model/model/stepfun.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ def get_model_tokenizer_got_ocr2(*args, **kwargs):
 
 def get_model_tokenizer_got_ocr2_hf(model_dir, *args, **kwargs):
     from transformers.models.got_ocr2 import GotOcr2ForConditionalGeneration
-    GotOcr2ForConditionalGeneration._no_split_modules.append('GotOcr2VisionLayer')
+    GotOcr2ForConditionalGeneration._no_split_modules = ['GotOcr2VisionLayer']
     model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
     return model, processor
 

@@ -49,7 +49,7 @@ def get_model_tokenizer_got_ocr2_hf(model_dir, *args, **kwargs):
         ],
         TemplateType.got_ocr2_hf,
         get_model_tokenizer_got_ocr2_hf,
-        model_arch=ModelArch.got_ocr2_hf,
+        model_arch=ModelArch.llava_hf,
         architectures=['GOTQwenForCausalLM'],
         tags=['vision']))
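
Two things change for GOT-OCR2 (hf): the loader now assigns _no_split_modules instead of appending to it, and the model reuses ModelArch.llava_hf rather than a dedicated got_ocr2_hf entry. The assignment fix avoids the classic mutable-class-attribute pitfall; a self-contained illustration (not code from this commit):

# Appending mutates shared class state, so calling the loader twice
# accumulates duplicate entries; plain assignment is idempotent.
class Dummy:
    _no_split_modules = []

def loader_with_append():
    Dummy._no_split_modules.append('GotOcr2VisionLayer')

loader_with_append()
loader_with_append()
print(Dummy._no_split_modules)  # ['GotOcr2VisionLayer', 'GotOcr2VisionLayer']

def loader_with_assign():
    Dummy._no_split_modules = ['GotOcr2VisionLayer']

loader_with_assign()
loader_with_assign()
print(Dummy._no_split_modules)  # ['GotOcr2VisionLayer']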

swift/llm/model/model_arch.py

Lines changed: 68 additions & 38 deletions
@@ -2,6 +2,11 @@
 from dataclasses import dataclass, field
 from typing import List, Optional, Union
 
+import transformers
+from packaging import version
+
+transformers_ge_4_52 = version.parse(transformers.__version__) >= version.parse('4.52')
+
 
 class LLMModelArch:
     qwen = 'qwen'
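
The flag added above is evaluated once at import time. A minimal standalone sketch of the same gate (assuming only that transformers and packaging are installed):

import transformers
from packaging import version

# packaging compares release segments numerically; naive string comparison
# would get '4.9' vs '4.52' wrong.
assert version.parse('4.52') > version.parse('4.9')

TRANSFORMERS_GE_4_52 = version.parse(transformers.__version__) >= version.parse('4.52')
print(TRANSFORMERS_GE_4_52)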
@@ -33,7 +38,6 @@ class MLLMModelArch:
 
     llama3_1_omni = 'llama3_1_omni'
     llama3_2_vision = 'llama3_2_vision'
-    llama4 = 'llama4'
 
     llava_hf = 'llava_hf'
     llava_next_video_hf = 'llava_next_video_hf'

@@ -59,14 +63,12 @@ class MLLMModelArch:
     idefics3 = 'idefics3'
 
     got_ocr2 = 'got_ocr2'
-    got_ocr2_hf = 'got_ocr2_hf'
 
     ovis1_6 = 'ovis1_6'
     molmo = 'molmo'
     emu3_chat = 'emu3_chat'
     megrez_omni = 'megrez_omni'
     valley = 'valley'
-    gemma3_vision = 'gemma3_vision'
     mistral_2503 = 'mistral_2503'
 

@@ -308,13 +310,22 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
     lm_head='lm_head',
 ))
 
-register_model_arch(
-    MultiModelKeys(
-        MLLMModelArch.llava_hf,
-        language_model='language_model',
-        aligner='multi_modal_projector',
-        vision_tower='vision_tower',
-    ))
+if transformers_ge_4_52:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llava_hf,
+            language_model='model.language_model',
+            aligner='model.multi_modal_projector',
+            vision_tower='model.vision_tower',
+        ))
+else:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llava_hf,
+            language_model='language_model',
+            aligner='multi_modal_projector',
+            vision_tower='vision_tower',
+        ))
 
 register_model_arch(
     MultiModelKeys(
@@ -324,12 +335,20 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
         vision_tower='model.vision_tower',
     ))
 
-register_model_arch(
-    MultiModelKeys(
-        MLLMModelArch.llava_next_video_hf,
-        language_model='language_model',
-        aligner=['multi_modal_projector'],
-        vision_tower='vision_tower'))
+if transformers_ge_4_52:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llava_next_video_hf,
+            language_model='model.language_model',
+            aligner=['model.multi_modal_projector'],
+            vision_tower='model.vision_tower'))
+else:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llava_next_video_hf,
+            language_model='language_model',
+            aligner=['multi_modal_projector'],
+            vision_tower='vision_tower'))
 
 register_model_arch(
     MultiModelKeys(
@@ -459,13 +478,23 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
         vision_tower='audio_tower',
     ))
 
-register_model_arch(
-    MultiModelKeys(
-        MLLMModelArch.qwen2_vl,
-        language_model='model',
-        aligner='visual.merger',
-        vision_tower='visual',
-    ))
+if transformers_ge_4_52:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.qwen2_vl,
+            language_model='model.language_model',
+            aligner='model.visual.merger',
+            vision_tower='model.visual',
+        ))
+else:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.qwen2_vl,
+            language_model='model',
+            aligner='visual.merger',
+            vision_tower='visual',
+        ))
+
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.qwen2_5_omni,
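
The if/else is spelled out per architecture rather than derived. A tempting alternative (not what this commit does) would be to prefix the old keys with 'model.', but qwen2_vl shows why that fails: its language_model key changes from 'model' to 'model.language_model', not to 'model.model'. A sketch of the naive transform and where it breaks:

# Hypothetical helper -- fine for llava_hf-style entries, wrong for qwen2_vl.
def naive_prefix(key: str) -> str:
    return 'model.' + key

print(naive_prefix('vision_tower'))  # 'model.vision_tower' (matches the diff)
print(naive_prefix('model'))         # 'model.model' (diff wants 'model.language_model')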
@@ -507,13 +536,22 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
         vision_tower='model.vision_tower_high',
     ))
 
-register_model_arch(
-    MultiModelKeys(
-        MLLMModelArch.llama3_2_vision,
-        language_model='language_model',
-        aligner='multi_modal_projector',
-        vision_tower='vision_model',
-    ))
+if transformers_ge_4_52:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llama3_2_vision,
+            language_model='model.language_model',
+            aligner='model.multi_modal_projector',
+            vision_tower='model.vision_model',
+        ))
+else:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llama3_2_vision,
+            language_model='language_model',
+            aligner='multi_modal_projector',
+            vision_tower='vision_model',
+        ))
 
 register_model_arch(MultiModelKeys(
     MLLMModelArch.ovis1_6,
@@ -547,14 +585,6 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
         vision_tower=['model.vision_tower', 'model.qwen2vl_vision_tower'],
     ))
 
-register_model_arch(
-    MultiModelKeys(
-        MLLMModelArch.gemma3_vision,
-        language_model='language_model',
-        aligner='multi_modal_projector',
-        vision_tower='vision_tower',
-    ))
-
 
 def get_model_arch(arch_name: Optional[str]) -> Optional[MultiModelKeys]:
     return MODEL_ARCH_MAPPING.get(arch_name)
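
These dotted keys are resolved attribute-by-attribute on the loaded model; swift imports a deep_getattr helper from swift.llm for this (visible in the patcher.py diff below). A stand-in sketch of that resolution:

from functools import reduce

def resolve_key(model, dotted: str):
    """Walk a dotted path, e.g. 'model.visual.merger' -> model.model.visual.merger."""
    return reduce(getattr, dotted.split('.'), model)

# e.g. for Qwen2-VL, the vision tower is resolve_key(model, 'model.visual') on
# transformers>=4.52 and resolve_key(model, 'visual') on older releases.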

swift/llm/model/patcher.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 from transformers.modeling_outputs import SequenceClassifierOutputWithPast
 
 from swift.llm import deep_getattr, to_device, to_float_dtype
-from swift.utils import get_dist_setting, get_logger, is_mp_ddp, safe_ddp_context, use_torchacc
+from swift.utils import get_dist_setting, get_logger, is_mp, is_mp_ddp, safe_ddp_context, use_torchacc
 from swift.utils.torch_utils import _get_max_memory, _sync_max_memory, get_device_count
 from .utils import HfConfigFactory, get_llm_model
 

@@ -349,7 +349,7 @@ def new_get_cached_module_file(pretrained_model_name_or_path, *args, **kwargs):
 
 @contextmanager
 def patch_tp_plan(load_model: bool):
-    if not load_model or not is_mp_ddp() or version.parse(
+    if not load_model or not is_mp() or version.parse(
             transformers.__version__) < version.parse('4.50') or 'WORLD_SIZE' not in os.environ:
         yield
         return
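
Two tweaks here: the early exit now keys off is_mp() instead of is_mp_ddp(), which (judging by the names) extends the TP-plan patch from the MP+DDP combination to any model-parallel run, and is_mp joins the import list accordingly. The guard itself is the usual early-exit contextmanager pattern; a sketch with hypothetical patch/restore bodies, since the diff only shows the condition:

import os
from contextlib import contextmanager

@contextmanager
def patch_if(active: bool):
    if not active:
        yield  # condition not met: no-op on entry and exit
        return
    saved = os.environ.pop('EXAMPLE_TP_VAR', None)  # hypothetical setup
    try:
        yield
    finally:
        if saved is not None:
            os.environ['EXAMPLE_TP_VAR'] = saved  # hypothetical teardown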
