Skip to content

Commit 02d3ead

Browse files
committed
Merge branch 'main' into v2.1
2 parents 5e29b90 + e3f0f74 commit 02d3ead

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

swift/llm/utils/model.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3054,7 +3054,6 @@ def _new_func(*args, **kwargs):
30543054

30553055

30563056
def _patch_deepseek_vl(model) -> None:
3057-
30583057
if not hasattr(model, 'hf_device_map') or len(model.hf_device_map.values()) == 1:
30593058
return
30603059
if hasattr(model.language_model, '__old_forward'):
@@ -3078,11 +3077,6 @@ def _new_forward(*args, **kwargs) -> Tensor:
30783077
model.language_model.forward = _new_forward
30793078
model.language_model.__old_forward = __old_forward
30803079

3081-
model.prepare_inputs_embeds = MethodType(__prepare_inputs_embeds, model)
3082-
func_list = ['generate', 'get_input_embeddings', 'gradient_checkpointing_enable', 'forward']
3083-
_use_submodel_func(model, 'language_model', func_list)
3084-
model.generation_config = model.language_model.generation_config
3085-
30863080

30873081
@register_model(
30883082
ModelType.deepseek_vl_7b_chat,
@@ -3134,6 +3128,10 @@ def get_model_tokenizer_deepseek_vl(model_dir: str,
31343128
tokenizer.processor = processor
31353129
if load_model:
31363130
_patch_deepseek_vl(model)
3131+
model.prepare_inputs_embeds = MethodType(__prepare_inputs_embeds, model)
3132+
func_list = ['generate', 'get_input_embeddings', 'gradient_checkpointing_enable', 'forward']
3133+
_use_submodel_func(model, 'language_model', func_list)
3134+
model.generation_config = model.language_model.generation_config
31373135
return model, tokenizer
31383136

31393137

0 commit comments

Comments
 (0)