|
@@ -22,8 +22,8 @@
 
 from swift.utils import get_dist_setting, get_logger, is_mp, is_unsloth_available, patch_getattr
 from .constant import ModelType
-from .patcher import (patch_automodel, patch_automodel_for_sequence_classification, patch_get_dynamic_module,
-                      patch_mp_ddp, patch_tp_plan)
+from .patcher import (get_lm_head_model, patch_automodel, patch_automodel_for_sequence_classification,
+                      patch_get_dynamic_module, patch_mp_ddp, patch_tp_plan)
 from .utils import AttnImpl, HfConfigFactory, InitModelStrategy, ModelInfo, safe_snapshot_download
 
 GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel], PreTrainedTokenizerBase]]
@@ -665,11 +665,15 @@ def get_model_tokenizer( |
         num_new_tokens = tokenizer.add_special_tokens({'additional_special_tokens': new_special_tokens})
         if num_new_tokens > 0:
             logger.info(f'Added {num_new_tokens} new special tokens.')
-        if model is not None and model.config.vocab_size < len(tokenizer):
-            vocab_size = math.ceil(len(tokenizer) / 128) * 128
-            model.resize_token_embeddings(vocab_size)
-            # fix transformers==4.52.4 qwen2.5-vl
-            model.config.vocab_size = vocab_size
+
+        if model is not None:
+            llm_model = get_lm_head_model(model, model_meta)
+            origin_vocab_size = HfConfigFactory.get_config_attr(llm_model.config, 'vocab_size')
+            if origin_vocab_size < len(tokenizer):
+                vocab_size = math.ceil(len(tokenizer) / 128) * 128
+                llm_model.resize_token_embeddings(vocab_size)
+                # fix transformers==4.52.4 qwen2.5-vl
+                HfConfigFactory.set_config_attr(llm_model.config, 'vocab_size', vocab_size)
 
     problem_type = kwargs.get('problem_type')
     if problem_type is None and model_info.num_labels == 1:
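
For context, the hunk above rounds the embedding size up to a multiple of 128 after new special tokens are added, and now resolves the module that actually owns the LM head before resizing. The snippet below is a minimal sketch of the same idea using only the public transformers API; `get_lm_head_model` and `HfConfigFactory` are swift-internal helpers for locating the language-model submodule and its config inside multimodal wrappers, so this sketch assumes a plain text-only model whose `model.config.vocab_size` can be read and written directly. The checkpoint and the added token are illustrative, not taken from the diff.

```python
import math

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any causal LM whose config exposes vocab_size works.
model_id = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Register an extra special token, as get_model_tokenizer does for new_special_tokens.
num_new_tokens = tokenizer.add_special_tokens({'additional_special_tokens': ['<my_token>']})
if num_new_tokens > 0:
    print(f'Added {num_new_tokens} new special tokens.')

if model is not None and model.config.vocab_size < len(tokenizer):
    # Round up to a multiple of 128 so the resized embedding keeps a
    # hardware-friendly shape, mirroring the change above.
    vocab_size = math.ceil(len(tokenizer) / 128) * 128
    model.resize_token_embeddings(vocab_size)
    # Keep the config in sync with the resized embedding matrix
    # (the diff does this explicitly as a workaround for nested configs).
    model.config.vocab_size = vocab_size
```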
|