Skip to content

Commit be809df

Browse files
authored
fix telechat(#3825)
Co-authored-by: hjh <[email protected]>
1 parent d143313 commit be809df

File tree

1 file changed

+6 -3 lines changed (6 additions, 3 deletions)

1 file changed

+6 -3 lines changed (6 additions, 3 deletions)

swift/llm/model/model/telechat.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22

3+
from transformers import GenerationConfig
4+
35
from swift.llm import TemplateType
46
from ..constant import LLMModelType
57
from ..model_arch import ModelArch
@@ -8,9 +10,10 @@
810

911
def get_model_tokenizer_telechat(*args, **kwargs):
1012
model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
11-
if model is not None:
12-
for k in ['bos_token_id', 'eos_token_id', 'pad_token_id', 'user_token_id', 'bot_token_id']:
13-
setattr(tokenizer, k, getattr(model.generation_config, k))
13+
model_dir = args[0]
14+
generation_config = GenerationConfig.from_pretrained(model_dir)
15+
for k in ['bos_token_id', 'eos_token_id', 'pad_token_id', 'user_token_id', 'bot_token_id']:
16+
setattr(tokenizer, k, getattr(generation_config, k))
1417
return model, tokenizer
1518

1619

0 commit comments

Comments
 (0)