Skip to content

Commit 0f72fdd

Browse files
committed
[bugfix] fix internvl new_special_tokens (#5401)
1 parent deec84f commit 0f72fdd

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

swift/llm/model/patcher.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,22 +150,21 @@ def _check_imports(filename) -> List[str]:
150150
td.check_imports = _old_check_imports
151151

152152

153-
def get_lm_head_model(model, model_meta, lm_heads):
153+
def get_lm_head_model(model, model_meta=None, lm_heads=None):
154+
model_meta = model_meta or model.model_meta
155+
lm_heads = lm_heads or ['lm_head']
154156
llm_prefix_list = getattr(model_meta.model_arch, 'language_model', None)
155157
prefix_list = []
156158
if llm_prefix_list:
157159
prefix_list = llm_prefix_list[0].split('.')
158160

159-
origin_model = model
160161
current_model = model
161-
for prefix in [None] + prefix_list:
162-
if prefix:
163-
current_model = getattr(current_model, prefix)
162+
for prefix in prefix_list:
163+
current_model = getattr(current_model, prefix)
164164
for lm_head in lm_heads:
165165
if hasattr(current_model, lm_head):
166166
return current_model
167-
168-
raise ValueError(f'Cannot find the lm_head. model: {origin_model}')
167+
return model
169168

170169

171170
def _patch_sequence_classification(model, model_meta):

swift/llm/model/register.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222

2323
from swift.utils import get_dist_setting, get_logger, is_mp, is_unsloth_available, patch_getattr
2424
from .constant import ModelType
25-
from .patcher import (patch_automodel, patch_automodel_for_sequence_classification, patch_get_dynamic_module,
26-
patch_mp_ddp, patch_tp_plan)
25+
from .patcher import (get_lm_head_model, patch_automodel, patch_automodel_for_sequence_classification,
26+
patch_get_dynamic_module, patch_mp_ddp, patch_tp_plan)
2727
from .utils import AttnImpl, HfConfigFactory, InitModelStrategy, ModelInfo, safe_snapshot_download
2828

2929
GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel], PreTrainedTokenizerBase]]
@@ -665,11 +665,15 @@ def get_model_tokenizer(
665665
num_new_tokens = tokenizer.add_special_tokens({'additional_special_tokens': new_special_tokens})
666666
if num_new_tokens > 0:
667667
logger.info(f'Added {num_new_tokens} new special tokens.')
668-
if model is not None and model.config.vocab_size < len(tokenizer):
669-
vocab_size = math.ceil(len(tokenizer) / 128) * 128
670-
model.resize_token_embeddings(vocab_size)
671-
# fix transformers==4.52.4 qwen2.5-vl
672-
model.config.vocab_size = vocab_size
668+
669+
if model is not None:
670+
llm_model = get_lm_head_model(model, model_meta)
671+
origin_vocab_size = HfConfigFactory.get_config_attr(llm_model.config, 'vocab_size')
672+
if origin_vocab_size < len(tokenizer):
673+
vocab_size = math.ceil(len(tokenizer) / 128) * 128
674+
llm_model.resize_token_embeddings(vocab_size)
675+
# fix transformers==4.52.4 qwen2.5-vl
676+
HfConfigFactory.set_config_attr(llm_model.config, 'vocab_size', vocab_size)
673677

674678
problem_type = kwargs.get('problem_type')
675679
if problem_type is None and model_info.num_labels == 1:

0 commit comments

Comments (0)