Skip to content

Commit 9ca8206

Browse files
committed
Merge branch 'main' into release/1.6
2 parents a3d8e7e + 2ea5db8 commit 9ca8206

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

swift/llm/utils/vllm_utils.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from modelscope import GenerationConfig, snapshot_download
99
from torch import dtype as Dtype
1010
from tqdm import tqdm
11+
from transformers import PreTrainedTokenizerBase
1112
from vllm import (AsyncEngineArgs, AsyncLLMEngine, EngineArgs, LLMEngine,
1213
SamplingParams)
1314

@@ -20,6 +21,13 @@
2021
logger = get_logger()
2122

2223

24+
def _get_vllm_tokenizer(vllm_engine: LLMEngine) -> PreTrainedTokenizerBase:
    """Return the tokenizer carried by ``vllm_engine``.

    ``vllm_engine.tokenizer`` is returned directly when it is already a
    ``PreTrainedTokenizerBase``. Otherwise it is treated as a wrapper object
    and its own ``tokenizer`` attribute is returned instead.

    Args:
        vllm_engine: Engine object exposing a ``tokenizer`` attribute.

    Returns:
        The engine's tokenizer, unwrapped by one level when necessary.
    """
    engine_tokenizer = vllm_engine.tokenizer
    if isinstance(engine_tokenizer, PreTrainedTokenizerBase):
        return engine_tokenizer
    # Not a tokenizer itself: unwrap one level.
    # NOTE(review): presumably the vllm==0.3.* wrapper case -- confirm.
    return engine_tokenizer.tokenizer
29+
30+
2331
def get_vllm_engine(model_type: str,
2432
torch_dtype: Optional[Dtype] = None,
2533
*,
@@ -89,7 +97,11 @@ def get_vllm_engine(model_type: str,
8997
llm_engine.engine_args = engine_args
9098
llm_engine.model_dir = model_dir
9199
llm_engine.model_type = model_type
92-
llm_engine.tokenizer = tokenizer
100+
if isinstance(llm_engine.tokenizer, PreTrainedTokenizerBase):
101+
llm_engine.tokenizer = tokenizer
102+
else:
103+
# compatible with vllm==0.3.*
104+
llm_engine.tokenizer.tokenizer = tokenizer
93105
generation_config_path = os.path.join(model_dir, 'generation_config.json')
94106
if os.path.isfile(generation_config_path):
95107
generation_config = GenerationConfig.from_pretrained(model_dir)
@@ -330,7 +342,7 @@ def prepare_vllm_engine_template(
330342
max_model_len=args.max_model_len,
331343
use_async=use_async,
332344
**kwargs)
333-
tokenizer = llm_engine.tokenizer
345+
tokenizer = _get_vllm_tokenizer(llm_engine)
334346
if use_async:
335347
model_config = asyncio.run(llm_engine.get_model_config())
336348
llm_engine.model_config = model_config

tests/llm/test_run.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,7 @@ def test_self_cognition(self):
208208
return
209209
for dataset in [[], [DatasetName.alpaca_zh, DatasetName.alpaca_en]]:
210210
sft_args = SftArguments(
211-
model_type=ModelType.qwen_7b_chat,
211+
model_type=ModelType.qwen1half_1_8b_chat_int4,
212212
dataset=dataset, # no dataset
213213
train_dataset_sample=100,
214214
dtype='fp16',

0 commit comments

Comments
 (0)