Commit 7919369

[compat] compat transformers main branch (v5) (#7895)
1 parent b0b9ee4 commit 7919369
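
This commit adapts swift to the transformers v5 main branch: the Qwen3-VL visual tower there returns an output object instead of a tuple, so the Megatron and HF forward paths gain a hasattr(..., 'pooler_output') fallback; the v5 compatibility patch moves from _postprocess_model into get_model; _add_new_special_tokens now resolves the tokenizer from the processor; and the OLMoE deepspeed zero3 leaf-module check keys on hf_model_type.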

File tree

swift/megatron/model/mm_gpts/qwen3_vl.py
swift/model/models/qwen.py
swift/model/register.py
swift/template/base.py

4 files changed, +33 -8 lines changed


swift/megatron/model/mm_gpts/qwen3_vl.py

Lines changed: 12 additions & 2 deletions
@@ -65,7 +65,12 @@ def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config):
         media_inputs = processor.image_processor(images=images, return_tensors='pt')
         media_inputs = to_device(media_inputs, input_ids.device)
         pixel_values = media_inputs['pixel_values'].type(dtype)
-        image_embeds, deepstack_visual_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+        visual_res = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+        if hasattr(visual_res, 'pooler_output'):
+            image_embeds = visual_res.pooler_output
+            deepstack_visual_embeds = visual_res.deepstack_features
+        else:
+            image_embeds, deepstack_visual_embeds = visual_res
         deepstack_visual_embeds = torch.stack(deepstack_visual_embeds, dim=0)
         inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0.
         visual_pos_masks = None
@@ -80,7 +85,12 @@ def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config):
         pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0)
         grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
         pixel_values_mixed = pixel_values_mixed.type(dtype)
-        mixed_embeds, deepstack_visual_embeds = visual(pixel_values_mixed, grid_thw=grid_thw)
+        visual_res = visual(pixel_values_mixed, grid_thw=grid_thw)
+        if hasattr(visual_res, 'pooler_output'):
+            mixed_embeds = visual_res.pooler_output
+            deepstack_visual_embeds = visual_res.deepstack_features
+        else:
+            mixed_embeds, deepstack_visual_embeds = visual_res
         if pixel_values is None:
             image_embeds = None
             video_embeds = mixed_embeds
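
For context: on the transformers v5 main branch the Qwen3-VL visual tower returns an output object exposing pooler_output and deepstack_features, while v4.x returns a plain (embeds, deepstack) tuple. A minimal sketch of the fallback both hunks apply, with a hypothetical helper name (_unpack_visual_output does not appear in the diff):

# A minimal sketch of the fallback applied in both hunks above. The helper
# name is hypothetical; the attribute names `pooler_output` and
# `deepstack_features` are taken from the diff.
def _unpack_visual_output(visual_res):
    """Normalize the visual tower's return value across transformers versions."""
    if hasattr(visual_res, 'pooler_output'):  # transformers v5 output object
        return visual_res.pooler_output, visual_res.deepstack_features
    return visual_res  # v4.x: already an (embeds, deepstack) tuple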

swift/model/models/qwen.py

Lines changed: 12 additions & 2 deletions
@@ -859,7 +859,12 @@ def _forward_qwen3_vl_or_qwen3_omni(
         media_inputs = processor.image_processor(images=images, return_tensors='pt')
         media_inputs = to_device(media_inputs, input_ids.device)
         pixel_values = media_inputs['pixel_values'].type(dtype)
-        image_embeds, deepstack_visual_embeds = self.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+        visual_res = self.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+        if hasattr(visual_res, 'pooler_output'):
+            image_embeds = visual_res.pooler_output
+            deepstack_visual_embeds = visual_res.deepstack_features
+        else:
+            image_embeds, deepstack_visual_embeds = visual_res
         inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0.
         visual_pos_masks = None
     else:
@@ -873,7 +878,12 @@ def _forward_qwen3_vl_or_qwen3_omni(
         pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0)
         grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
         pixel_values_mixed = pixel_values_mixed.type(dtype)
-        mixed_embeds, deepstack_visual_embeds = self.visual(pixel_values_mixed, grid_thw=grid_thw)
+        visual_res = self.visual(pixel_values_mixed, grid_thw=grid_thw)
+        if hasattr(visual_res, 'pooler_output'):
+            mixed_embeds = visual_res.pooler_output
+            deepstack_visual_embeds = visual_res.deepstack_features
+        else:
+            mixed_embeds, deepstack_visual_embeds = visual_res
         if pixel_values is None:
             image_embeds = None
             video_embeds = mixed_embeds
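
This is the same tuple-versus-output fallback as in the Megatron hunks above, applied on the HF forward path where the visual tower is self.visual; the _unpack_visual_output sketch above covers both call sites.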

swift/model/register.py

Lines changed: 5 additions & 4 deletions
@@ -308,6 +308,8 @@ def get_model(self, model_dir: str, config: PretrainedConfig, processor: Process
             patch_output_normalizer(model, model_meta=model_meta)
         elif model_info.task_type == 'generative_reranker':
             self._patch_generative_reranker(model, processor)
+        if version.parse(transformers.__version__) >= version.parse('5.0.0.dev'):
+            self._compat_transformers5(model)
         return model

     def _patch_generative_reranker(self, model, processor):
@@ -328,17 +330,16 @@ def _postprocess_model(self, model_dir, model):
         if self.leaf_modules is not None or model_info.is_moe_model:
             # deepspeed zero3
             self._deepspeed_set_z3_leaf_modules(model, self.leaf_modules)
-        if version.parse(transformers.__version__) >= version.parse('5.0.0.dev'):
-            self._compat_transformers5(model)
         model.model_info = self.model_info
         model.model_meta = self.model_meta
         model.model_dir = model_dir
         self._init_generation_config(model, model_dir)
         HfConfigFactory.set_model_config_attr(model, 'pad_token_id', self.pad_token)

-    def _add_new_special_tokens(self, model, tokenizer):
+    def _add_new_special_tokens(self, model, processor):
         if not self.new_special_tokens:
             return
+        tokenizer = self._get_tokenizer(processor)
         num_new_tokens = tokenizer.add_special_tokens({'additional_special_tokens': self.new_special_tokens})
         if num_new_tokens > 0:
             logger.info(f'Added {num_new_tokens} new special tokens.')
@@ -414,7 +415,7 @@ def _deepspeed_set_z3_leaf_modules(self, model, z3_leaf_modules):
         elif hf_model_type == 'qwen3_next':
             from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
             z3_leaf_modules = [Qwen3NextSparseMoeBlock]
-        elif architecture == 'OlmoeForCausalLM':
+        elif hf_model_type == 'olmoe':
             from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
             z3_leaf_modules = [OlmoeSparseMoeBlock]

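
Two behavioral changes here beyond the diff mechanics: the v5 compatibility patch now runs in get_model rather than _postprocess_model, and _add_new_special_tokens accepts the processor and resolves the tokenizer via self._get_tokenizer. A minimal sketch of the version gate, assuming the standard packaging parser used by the surrounding code:

# A minimal sketch of the version gate moved into get_model. Comparing
# against '5.0.0.dev' (which packaging normalizes to 5.0.0.dev0) makes the
# check true for pre-release builds of the v5 main branch as well as the
# final 5.0.0 release.
import transformers
from packaging import version

def _is_transformers_v5() -> bool:  # hypothetical helper name
    return version.parse(transformers.__version__) >= version.parse('5.0.0.dev')

The OLMoE branch in _deepspeed_set_z3_leaf_modules now keys on hf_model_type == 'olmoe', consistent with the neighboring elif arms, instead of the architecture string.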

swift/template/base.py

Lines changed: 4 additions & 0 deletions
@@ -2112,6 +2112,8 @@ def _get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config):
         media_inputs = to_device(media_inputs, input_ids.device)
         pixel_values = media_inputs['pixel_values'].type(dtype)
         image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+        if hasattr(image_embeds, 'pooler_output'):
+            image_embeds = image_embeds.pooler_output
         inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0.
     else:
         if pixel_values is None:
@@ -2125,6 +2127,8 @@ def _get_inputs_embeds_hf(inputs_embeds, inputs, visual, processor, config):
         grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
         pixel_values_mixed = pixel_values_mixed.type(dtype)
         mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw)
+        if hasattr(mixed_embeds, 'pooler_output'):
+            mixed_embeds = mixed_embeds.pooler_output
         if pixel_values is None:
             image_embeds = None
             video_embeds = mixed_embeds
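
On this HF template path the v4.x call returns a bare tensor (no deepstack features), so only the pooler_output unwrap is needed. A minimal sketch, again with a hypothetical helper name:

# A minimal sketch of the unwrap applied in both hunks above: keep a bare
# tensor as-is (v4.x), or take `pooler_output` from a v5 output object.
def _unwrap_embeds(out):
    return out.pooler_output if hasattr(out, 'pooler_output') else out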
