@@ -3,7 +3,7 @@
 from typing import Any, Dict, Optional, Tuple, Type

 import torch
-from transformers import AutoConfig, BitsAndBytesConfig, PreTrainedTokenizerBase
+from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizerBase
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.models.auto.tokenization_auto import get_tokenizer_config

@@ -40,15 +40,19 @@ def get_model_tokenizer_qwen(model_dir: str,
     use_flash_attn = AttnImpl.to_use_flash_attn(kwargs.pop('attn_impl', None), 'auto')
     model_config.use_flash_attn = use_flash_attn
     kwargs['model_config'] = model_config
+    tokenizer = kwargs.get('tokenizer')
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token_id = tokenizer.eod_id
+    kwargs['tokenizer'] = tokenizer
     model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
     try:
         # fix mp+ddp bug
         model.transformer.registered_causal_mask = model.transformer.registered_causal_mask.cuda()
         logger.info('registered_causal_mask to cuda')
     except AttributeError:
         pass
-    if tokenizer.eos_token_id is None:
-        tokenizer.eos_token_id = tokenizer.eod_id
     return model, tokenizer

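With this change, the tokenizer is resolved up front and placed into kwargs, so its missing eos_token_id is filled in from eod_id before get_model_tokenizer_with_flash_attn ever receives it, rather than being patched after the model is loaded. Below is a minimal standalone sketch of that fallback, assuming the original tiktoken-based Qwen checkpoint; the 'Qwen/Qwen-7B' id is only an illustration (the function above uses the local model_dir instead).

# Sketch of the eos_token_id fallback from the patch above.
# Assumption: 'Qwen/Qwen-7B' stands in for model_dir; its remote QWenTokenizer
# exposes an end-of-document id (`eod_id`) but no eos token, so eos_token_id is None.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B', trust_remote_code=True)
if tokenizer.eos_token_id is None:
    # Reuse the end-of-document id as eos so downstream generation and stopping
    # logic have a valid terminator before the model itself is loaded.
    tokenizer.eos_token_id = tokenizer.eod_id
print(tokenizer.eos_token_id)  # now populated from eod_id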