|
8 | 8 | from modelscope import GenerationConfig, snapshot_download |
9 | 9 | from torch import dtype as Dtype |
10 | 10 | from tqdm import tqdm |
| 11 | +from transformers import PreTrainedTokenizerBase |
11 | 12 | from vllm import (AsyncEngineArgs, AsyncLLMEngine, EngineArgs, LLMEngine, |
12 | 13 | SamplingParams) |
13 | 14 |
|
|
20 | 21 | logger = get_logger() |
21 | 22 |
|
22 | 23 |
|
def _get_vllm_tokenizer(vllm_engine: LLMEngine) -> PreTrainedTokenizerBase:
    """Return the Hugging Face tokenizer held by a vLLM engine.

    Recent vLLM releases wrap the tokenizer in a group object that
    exposes the real tokenizer via a ``.tokenizer`` attribute; older
    releases store the ``PreTrainedTokenizerBase`` directly. Unwrap
    one level when needed so callers always get the HF tokenizer.
    """
    candidate = vllm_engine.tokenizer
    if isinstance(candidate, PreTrainedTokenizerBase):
        # Old-style engine: attribute already is the HF tokenizer.
        return candidate
    # New-style engine (e.g. vllm>=0.3): a TokenizerGroup-like wrapper.
    return candidate.tokenizer
| 29 | + |
| 30 | + |
23 | 31 | def get_vllm_engine(model_type: str, |
24 | 32 | torch_dtype: Optional[Dtype] = None, |
25 | 33 | *, |
@@ -89,7 +97,11 @@ def get_vllm_engine(model_type: str, |
89 | 97 | llm_engine.engine_args = engine_args |
90 | 98 | llm_engine.model_dir = model_dir |
91 | 99 | llm_engine.model_type = model_type |
92 | | - llm_engine.tokenizer = tokenizer |
| 100 | + if isinstance(llm_engine.tokenizer, PreTrainedTokenizerBase): |
| 101 | + llm_engine.tokenizer = tokenizer |
| 102 | + else: |
| 103 | + # compatible with vllm==0.3.* |
| 104 | + llm_engine.tokenizer.tokenizer = tokenizer |
93 | 105 | generation_config_path = os.path.join(model_dir, 'generation_config.json') |
94 | 106 | if os.path.isfile(generation_config_path): |
95 | 107 | generation_config = GenerationConfig.from_pretrained(model_dir) |
@@ -330,7 +342,7 @@ def prepare_vllm_engine_template( |
330 | 342 | max_model_len=args.max_model_len, |
331 | 343 | use_async=use_async, |
332 | 344 | **kwargs) |
333 | | - tokenizer = llm_engine.tokenizer |
| 345 | + tokenizer = _get_vllm_tokenizer(llm_engine) |
334 | 346 | if use_async: |
335 | 347 | model_config = asyncio.run(llm_engine.get_model_config()) |
336 | 348 | llm_engine.model_config = model_config |
|
0 commit comments