|
5 | 5 | import time |
6 | 6 | from contextlib import contextmanager |
7 | 7 | from copy import deepcopy |
| 8 | +from functools import wraps |
8 | 9 | from queue import Queue |
9 | 10 | from threading import Thread |
10 | 11 | from typing import Any, Dict, Iterator, List, Optional, Tuple, Union |
|
16 | 17 | from lmdeploy.serve.async_engine import AsyncEngine |
17 | 18 | from lmdeploy.serve.vl_async_engine import VLAsyncEngine |
18 | 19 | from tqdm import tqdm |
19 | | -from transformers import AutoConfig, GenerationConfig |
| 20 | +from transformers import AutoConfig, AutoTokenizer, GenerationConfig |
20 | 21 |
|
21 | 22 | from swift.utils import get_logger |
22 | 23 | from .argument import InferArguments |
@@ -69,7 +70,16 @@ def get_lmdeploy_engine( |
69 | 70 | pipeline_kwargs['vision_config'] = vision_config |
70 | 71 | logger.info(f'vision_config: {vision_config}') |
71 | 72 |
|
| 73 | + _old_from_pretrained = AutoTokenizer.from_pretrained |
| 74 | + |
| 75 | + @wraps(_old_from_pretrained) |
| 76 | + def _from_pretrained(self, *args, **kwargs): |
| 77 | + return tokenizer |
| 78 | + |
| 79 | + AutoTokenizer.from_pretrained = _from_pretrained |
72 | 80 | lmdeploy_engine = pipeline(model_dir, backend_config=backend_config, **pipeline_kwargs) |
| 81 | + AutoTokenizer.from_pretrained = _old_from_pretrained # recover |
| 82 | + |
73 | 83 | lmdeploy_engine.model_dir = model_dir |
74 | 84 | lmdeploy_engine.model_type = model_type |
75 | 85 | lmdeploy_engine.is_multimodal = is_multimodal |
|
0 commit comments