
Commit 99d2eeb

[model] fix qwen eos_token (#4888)

1 parent 68aceef

File tree

1 file changed: +7 / -3 lines

swift/llm/model/model/qwen.py

Lines changed: 7 additions & 3 deletions
@@ -3,7 +3,7 @@
 from typing import Any, Dict, Optional, Tuple, Type

 import torch
-from transformers import AutoConfig, BitsAndBytesConfig, PreTrainedTokenizerBase
+from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizerBase
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.models.auto.tokenization_auto import get_tokenizer_config

@@ -40,15 +40,19 @@ def get_model_tokenizer_qwen(model_dir: str,
     use_flash_attn = AttnImpl.to_use_flash_attn(kwargs.pop('attn_impl', None), 'auto')
     model_config.use_flash_attn = use_flash_attn
     kwargs['model_config'] = model_config
+    tokenizer = kwargs.get('tokenizer')
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token_id = tokenizer.eod_id
+    kwargs['tokenizer'] = tokenizer
     model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
     try:
         # fix mp+ddp bug
         model.transformer.registered_causal_mask = model.transformer.registered_causal_mask.cuda()
         logger.info('registered_causal_mask to cuda')
     except AttributeError:
         pass
-    if tokenizer.eos_token_id is None:
-        tokenizer.eos_token_id = tokenizer.eod_id
     return model, tokenizer
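As context for the fix (not part of the commit): Qwen's original tokenizer defines eod_id for its <|endoftext|> token but can ship without eos_token_id, which leaves generation with no stop token. The Python sketch below mirrors the fallback the diff applies; DummyQwenTokenizer, ensure_eos_token, and the token id are hypothetical stand-ins for illustration only.

# Minimal sketch of the eos_token fallback applied in this commit.
# DummyQwenTokenizer is a hypothetical stand-in for the Qwen tokenizer:
# it exposes an eod_id (<|endoftext|>) but has no eos_token_id set.

class DummyQwenTokenizer:
    def __init__(self):
        self.eod_id = 151643      # hypothetical <|endoftext|> id, illustration only
        self.eos_token_id = None  # the missing value the commit works around


def ensure_eos_token(tokenizer):
    # Reuse eod_id as the eos token when eos_token_id is unset,
    # mirroring the check the diff moves ahead of model loading.
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token_id = tokenizer.eod_id
    return tokenizer


tokenizer = ensure_eos_token(DummyQwenTokenizer())
assert tokenizer.eos_token_id == tokenizer.eod_id  # generation now has a stop token

In the commit itself this check now runs before get_model_tokenizer_with_flash_attn, so the tokenizer handed through kwargs['tokenizer'] already carries a valid stop token while the model is constructed, rather than being patched only after loading.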
