Commit 96e08b2

fix vllm==0.4.3 (#1055)
1 parent b42326e commit 96e08b2
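
vllm 0.4.3 changed the call signatures of AsyncLLMEngine.generate and LLMEngine.add_request: the prompt is now passed up front as a dict such as {'prompt_token_ids': input_ids}, rather than through a trailing positional prompt_token_ids argument. This commit gates both call sites on the installed vllm version via packaging.version, and separately fixes how SwiftPreprocessor normalizes the history column.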

File tree

swift/llm/deploy.py
swift/llm/utils/preprocess.py
swift/llm/utils/vllm_utils.py

3 files changed: +26 -11 lines changed

swift/llm/deploy.py

Lines changed: 8 additions & 1 deletion

@@ -11,6 +11,7 @@
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, StreamingResponse
 from modelscope import GenerationConfig
+from packaging import version
 from peft import PeftModel
 
 from swift.utils import get_logger, get_main, seed_everything
@@ -162,7 +163,13 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionRequest],
                 break
         assert lora_request is not None
         generate_kwargs['lora_request'] = lora_request
-    result_generator = llm_engine.generate(None, generation_config, request_id, input_ids, **generate_kwargs)
+
+    import vllm
+    if version.parse(vllm.__version__) >= version.parse('0.4.3'):
+        result_generator = llm_engine.generate({'prompt_token_ids': input_ids}, generation_config, request_id,
+                                               **generate_kwargs)
+    else:
+        result_generator = llm_engine.generate(None, generation_config, request_id, input_ids, **generate_kwargs)
 
     async def _generate_full():
         result = None
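
For reference, the gate in deploy.py can be read as the following standalone sketch. generate_compat is a hypothetical helper name used only for illustration; the two call signatures mirror the diff above.

    from packaging import version
    import vllm

    def generate_compat(llm_engine, input_ids, generation_config, request_id, **generate_kwargs):
        # vllm >= 0.4.3: token ids travel in the prompt dict (first argument).
        if version.parse(vllm.__version__) >= version.parse('0.4.3'):
            return llm_engine.generate({'prompt_token_ids': input_ids}, generation_config, request_id,
                                       **generate_kwargs)
        # Older vllm: prompt=None, token ids passed positionally after request_id.
        return llm_engine.generate(None, generation_config, request_id, input_ids, **generate_kwargs)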

swift/llm/utils/preprocess.py

Lines changed: 10 additions & 8 deletions

@@ -15,17 +15,19 @@ class SwiftPreprocessor:
     def __call__(self, dataset: HfDataset) -> HfDataset:
         if 'history' in dataset.features:
             old_history = dataset['history']
-
+            has_history = False
             history: List[History] = []
-            for old_h in tqdm(old_history):
-                if isinstance(old_h, list):
-                    break
-                h = None
-                if old_h is not None:
+            for h in tqdm(old_history):
+                if isinstance(h, str):
                     h = ast.literal_eval(h)
+                elif h is None:
+                    h = []
+                if len(h) > 0:
+                    has_history = True
                 history.append(h)
-            else:
-                dataset = dataset.remove_columns(['history']).add_column('history', history)
+            dataset = dataset.remove_columns(['history'])
+            if has_history:
+                dataset = dataset.add_column('history', history)
         return dataset
 
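The rewritten loop normalizes every history entry to a list before deciding whether to keep the column: strings are parsed with ast.literal_eval, None becomes an empty list, and the column is re-added only if at least one non-empty history remains. A small self-contained illustration of that normalization (the sample data below is invented):

    import ast

    old_history = ["[['hi', 'hello']]", None, []]

    has_history = False
    history = []
    for h in old_history:
        if isinstance(h, str):
            h = ast.literal_eval(h)  # "[['hi', 'hello']]" -> [['hi', 'hello']]
        elif h is None:
            h = []
        if len(h) > 0:
            has_history = True
        history.append(h)

    print(history)      # [[['hi', 'hello']], [], []]
    print(has_history)  # True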
swift/llm/utils/vllm_utils.py

Lines changed: 8 additions & 2 deletions

@@ -266,7 +266,10 @@ def inference_stream_vllm(llm_engine: LLMEngine,
             resp_list[i] = {'response': '', 'history': history}
             continue
         input_ids = inputs['input_ids']
-        llm_engine.add_request(str(i), None, generation_config, input_ids, **add_request_kwargs)
+        if version.parse(vllm.__version__) >= version.parse('0.4.3'):
+            llm_engine.add_request(str(i), {'prompt_token_ids': input_ids}, generation_config, **add_request_kwargs)
+        else:
+            llm_engine.add_request(str(i), None, generation_config, input_ids, **add_request_kwargs)
 
     print_idx_list = [[0] for _ in range(len(request_list))]
     prog_bar = tqdm(total=len(request_list), dynamic_ncols=True, disable=not use_tqdm)
@@ -353,7 +356,10 @@ def inference_vllm(llm_engine: LLMEngine,
             resp_list[i] = {'response': '', 'history': history}
             continue
         input_ids = inputs['input_ids']
-        llm_engine.add_request(str(i), None, generation_config, input_ids, **add_request_kwargs)
+        if version.parse(vllm.__version__) >= version.parse('0.4.3'):
+            llm_engine.add_request(str(i), {'prompt_token_ids': input_ids}, generation_config, **add_request_kwargs)
+        else:
+            llm_engine.add_request(str(i), None, generation_config, input_ids, **add_request_kwargs)
 
     if use_tqdm is True:
         assert verbose is False
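
inference_stream_vllm and inference_vllm apply the same version gate to LLMEngine.add_request, whose signature changed in vllm 0.4.3 in the same way as generate. The shared pattern, factored into a hypothetical helper (add_request_compat is not a name from the commit):

    from packaging import version
    import vllm

    def add_request_compat(llm_engine, request_id, input_ids, generation_config, **add_request_kwargs):
        # vllm >= 0.4.3 expects the prompt as a dict of token ids.
        if version.parse(vllm.__version__) >= version.parse('0.4.3'):
            llm_engine.add_request(request_id, {'prompt_token_ids': input_ids}, generation_config,
                                   **add_request_kwargs)
        else:
            # Older vllm: prompt=None plus positional prompt_token_ids.
            llm_engine.add_request(request_id, None, generation_config, input_ids, **add_request_kwargs)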
