@@ -3,7 +3,7 @@
 from typing import Any, Dict, Optional, Tuple, Type
 
 import torch
-from transformers import AutoConfig, BitsAndBytesConfig, PreTrainedTokenizerBase
+from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizerBase
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.models.auto.tokenization_auto import get_tokenizer_config
 
@@ -40,15 +40,19 @@ def get_model_tokenizer_qwen(model_dir: str,
     use_flash_attn = AttnImpl.to_use_flash_attn(kwargs.pop('attn_impl', None), 'auto')
     model_config.use_flash_attn = use_flash_attn
     kwargs['model_config'] = model_config
+    tokenizer = kwargs.get('tokenizer')
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token_id = tokenizer.eod_id
+    kwargs['tokenizer'] = tokenizer
     model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
     try:
         # fix mp+ddp bug
         model.transformer.registered_causal_mask = model.transformer.registered_causal_mask.cuda()
         logger.info('registered_causal_mask to cuda')
     except AttributeError:
         pass
-    if tokenizer.eos_token_id is None:
-        tokenizer.eos_token_id = tokenizer.eod_id
     return model, tokenizer
 
 
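For context, the new hunk resolves the tokenizer before the model is constructed (reusing one passed via kwargs when available) and maps Qwen's eod_id onto eos_token_id up front, instead of patching the tokenizer only after loading. Below is a minimal sketch of that fallback, using a hypothetical stand-in object rather than a real Qwen tokenizer (in the actual code the tokenizer comes from AutoTokenizer.from_pretrained and exposes eod_id for its <|endoftext|> token); ensure_eos_token_id is an illustrative helper, not part of the patch.

from types import SimpleNamespace

def ensure_eos_token_id(tokenizer):
    # Same fallback as in the hunk above: Qwen(v1)-style tokenizers expose
    # eod_id but may leave eos_token_id unset, so generation has no stop token.
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token_id = tokenizer.eod_id
    return tokenizer

# Hypothetical stand-in; 151643 is the usual id of Qwen's <|endoftext|> token.
tok = SimpleNamespace(eos_token_id=None, eod_id=151643)
print(ensure_eos_token_id(tok).eos_token_id)  # -> 151643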