Commit b5925b3

support lmdeploy & app-ui (#1546)
1 parent bf51950 commit b5925b3
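
In short, this commit lets the Gradio app-ui run on the LMDeploy backend in addition to vLLM, and routes both backends through a shared streaming interface. A hypothetical launch of the new path might look like the sketch below; the import locations and the exact AppUIArguments fields (e.g. model_type) are assumptions based on how app_ui.py uses them in this diff, not a verbatim quote of the library API.

    from swift.llm import AppUIArguments
    from swift.llm.app_ui import gradio_chat_demo

    # Hypothetical arguments: any LMDeploy-capable model_type would do here.
    args = AppUIArguments(model_type='qwen-7b-chat', infer_backend='lmdeploy')
    # gradio_chat_demo now builds the engine via prepare_lmdeploy_engine_template
    # when infer_backend == 'lmdeploy' (see swift/llm/app_ui.py below).
    gradio_chat_demo(args)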

6 files changed: +30 additions, -24 deletions


requirements/framework.txt

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 accelerate
 addict
 aiohttp
+attrdict
 binpacking
 dacite
 datasets<2.19

requirements/llm.txt

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-attrdict
 charset_normalizer
 cpm_kernels
 fastapi

swift/llm/app_ui.py

Lines changed: 12 additions & 6 deletions
@@ -15,14 +15,17 @@ def clear_session() -> History:
 def gradio_generation_demo(args: AppUIArguments) -> None:
     import gradio as gr
     if args.infer_backend == 'vllm':
-        from swift.llm import prepare_vllm_engine_template, inference_stream_vllm
+        from swift.llm import prepare_vllm_engine_template, inference_stream_vllm as inference_stream_x
         llm_engine, template = prepare_vllm_engine_template(args)
+    elif args.infer_backend == 'lmdeploy':
+        from swift.llm import prepare_lmdeploy_engine_template, inference_stream_lmdeploy as inference_stream_x
+        llm_engine, template = prepare_lmdeploy_engine_template(args)
     else:
         model, template = prepare_model_template(args)

     def model_generation(query: str) -> Iterator[str]:
-        if args.infer_backend == 'vllm':
-            gen = inference_stream_vllm(llm_engine, template, [{'query': query}])
+        if args.infer_backend in {'vllm', 'lmdeploy'}:
+            gen = inference_stream_x(llm_engine, template, [{'query': query}])
             for resp_list in gen:
                 response = resp_list[0]['response']
                 yield response
@@ -52,15 +55,18 @@ def model_generation(query: str) -> Iterator[str]:
 def gradio_chat_demo(args: AppUIArguments) -> None:
     import gradio as gr
     if args.infer_backend == 'vllm':
-        from swift.llm import prepare_vllm_engine_template, inference_stream_vllm
+        from swift.llm import prepare_vllm_engine_template, inference_stream_vllm as inference_stream_x
         llm_engine, template = prepare_vllm_engine_template(args)
+    elif args.infer_backend == 'lmdeploy':
+        from swift.llm import prepare_lmdeploy_engine_template, inference_stream_lmdeploy as inference_stream_x
+        llm_engine, template = prepare_lmdeploy_engine_template(args)
     else:
         model, template = prepare_model_template(args)

     def model_chat(query: str, history: History) -> Iterator[Tuple[str, History]]:
         old_history, history = limit_history_length(template, query, history, args.max_length)
-        if args.infer_backend == 'vllm':
-            gen = inference_stream_vllm(llm_engine, template, [{'query': query, 'history': history}])
+        if args.infer_backend in {'vllm', 'lmdeploy'}:
+            gen = inference_stream_x(llm_engine, template, [{'query': query, 'history': history}])
             for resp_list in gen:
                 history = resp_list[0]['history']
                 total_history = old_history + history
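
Because both streaming functions are imported under the common alias inference_stream_x, the Gradio callbacks consume one generator shape regardless of backend. A self-contained sketch of that consumption pattern, with a stub generator standing in for the real engines; the yielded 'response'/'history' dicts follow the return format documented in vllm_utils.py further down:

    from typing import Any, Dict, Iterator, List

    def fake_inference_stream_x(request_list: List[Dict[str, Any]]) -> Iterator[List[Dict[str, Any]]]:
        # Stub for inference_stream_vllm / inference_stream_lmdeploy: each yield carries
        # one dict per request with the partial 'response' and the running 'history'.
        partial = ''
        for token in ('Hel', 'lo', '!'):
            partial += token
            yield [{'response': partial, 'history': [(request_list[0]['query'], partial)]}]

    # Same loop shape as model_generation / model_chat above.
    for resp_list in fake_inference_stream_x([{'query': 'hi'}]):
        print(resp_list[0]['response'])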

swift/llm/deploy.py

Lines changed: 9 additions & 17 deletions
@@ -4,7 +4,6 @@
 import logging
 import time
 from concurrent.futures import ThreadPoolExecutor
-from contextlib import nullcontext
 from dataclasses import asdict
 from http import HTTPStatus
 from typing import List, Optional, Union
@@ -101,17 +100,10 @@ async def _prepare_request(request: Union[ChatCompletionRequest, CompletionReque
     if not is_valid:
         return create_error_response(HTTPStatus.BAD_REQUEST, 'API key error')

-    if _args.infer_backend == 'vllm':
-        from .utils import vllm_context
-        model_or_engine = llm_engine
-        context = vllm_context(template)
-    elif _args.infer_backend == 'lmdeploy':
-        from .utils import lmdeploy_context
+    if _args.infer_backend in {'vllm', 'lmdeploy'}:
         model_or_engine = llm_engine
-        context = lmdeploy_context(template)
     else:
         model_or_engine = model
-        context = nullcontext(template)

     error_msg = await check_model(request)
     if error_msg is not None:
@@ -147,10 +139,9 @@ async def _prepare_request(request: Union[ChatCompletionRequest, CompletionReque
             example['tools'] = [tool]
         elif request.tool_choice == 'auto':
             example['tools'] = request.tools
-        with context:
-            executor = ThreadPoolExecutor(max_workers=1)
-            loop = asyncio.get_running_loop()
-            inputs = (await loop.run_in_executor(executor, template.encode, example))[0]
+        executor = ThreadPoolExecutor(max_workers=1)
+        loop = asyncio.get_running_loop()
+        inputs = (await loop.run_in_executor(executor, template.encode, example))[0]
         request_id = f'chatcmpl-{random_uuid()}'
         _request['messages'] = messages
     else:
@@ -167,10 +158,9 @@ async def _prepare_request(request: Union[ChatCompletionRequest, CompletionReque
         example = {'query': prompt}
         if len(images) > 0:
             example['images'] = images
-        with context:
-            executor = ThreadPoolExecutor(max_workers=1)
-            loop = asyncio.get_running_loop()
-            inputs = (await loop.run_in_executor(executor, template.encode, example))[0]
+        executor = ThreadPoolExecutor(max_workers=1)
+        loop = asyncio.get_running_loop()
+        inputs = (await loop.run_in_executor(executor, template.encode, example))[0]
         request_id = f'cmpl-{random_uuid()}'
         _request['prompt'] = prompt

@@ -709,9 +699,11 @@ def llm_deploy(args: DeployArguments) -> None:
     if args.infer_backend == 'vllm':
         from .utils import prepare_vllm_engine_template
         llm_engine, template = prepare_vllm_engine_template(args, use_async=True)
+        template._is_vllm = True
     elif args.infer_backend == 'lmdeploy':
         from .utils import prepare_lmdeploy_engine_template
         llm_engine, template = prepare_lmdeploy_engine_template(args)
+        template._is_lmdeploy = True
     else:
         model, template = prepare_model_template(args)
     uvicorn.run(app, host=args.host, port=args.port, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile)
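
The removed with-context blocks existed only to tell the template which backend was active while encoding; llm_deploy now records that once at startup via the _is_vllm / _is_lmdeploy flags, so _prepare_request can call template.encode directly. The surrounding pattern of pushing the blocking encode off the event loop is unchanged; an isolated, generic sketch of that pattern (stand-in names, not the actual deploy.py symbols):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    def encode(example: dict) -> tuple:
        # Stand-in for template.encode: CPU-bound tokenization returning (inputs, ...).
        return ({'input_ids': [1, 2, 3]},)

    async def prepare(example: dict) -> dict:
        # Same shape as _prepare_request: run the blocking encode in a worker thread.
        executor = ThreadPoolExecutor(max_workers=1)
        loop = asyncio.get_running_loop()
        inputs = (await loop.run_in_executor(executor, encode, example))[0]
        return inputs

    print(asyncio.run(prepare({'query': 'hi'})))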

swift/llm/utils/lmdeploy_utils.py

Lines changed: 4 additions & 0 deletions
@@ -204,6 +204,8 @@ def inference_stream_lmdeploy(lmdeploy_engine: Union[AsyncEngine, VLAsyncEngine]
                               generation_info: Optional[Dict[str, Any]] = None,
                               use_tqdm: bool = False,
                               **kwargs) -> List[Dict[str, Any]]:
+    if len(request_list) == 0:
+        return []
     start_runtime = time.perf_counter()
     if generation_config is None:
         generation_config = getattr(lmdeploy_engine, 'generation_config', LmdeployGenerationConfig())
@@ -292,6 +294,8 @@ def inference_lmdeploy(lmdeploy_engine: Union[AsyncEngine, VLAsyncEngine],
                        prompt_prefix: str = '[PROMPT]',
                        output_prefix: str = '[OUTPUT]',
                        **kwargs) -> List[Dict[str, Any]]:
+    if len(request_list) == 0:
+        return []
     runtime = time.perf_counter()
     if generation_config is None:
         generation_config = getattr(lmdeploy_engine, 'generation_config', LmdeployGenerationConfig())
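
Both LMDeploy entry points now short-circuit on an empty batch before any timing or engine work. The guard in isolation, as a generic runnable example (not the library code):

    from typing import Any, Dict, List

    def inference_batch(request_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Early return: downstream batching/statistics code never sees an empty list.
        if len(request_list) == 0:
            return []
        return [{'response': f"echo: {r['query']}", 'history': []} for r in request_list]

    assert inference_batch([]) == []
    assert inference_batch([{'query': 'hi'}])[0]['response'] == 'echo: hi'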

swift/llm/utils/vllm_utils.py

Lines changed: 4 additions & 0 deletions
@@ -376,6 +376,8 @@ def inference_stream_vllm(
     return: e.g. [{'response': 'hi!', 'history': [('hello!', 'hi!')]}].
         The keys to be included will be: 'response', 'history'.
     """
+    if len(request_list) == 0:
+        return []
     start_runtime = time.perf_counter()
     if generation_config is None:
         generation_config = getattr(llm_engine, 'generation_config', VllmGenerationConfig())
@@ -468,6 +470,8 @@ def inference_vllm(llm_engine: LLMEngine,
     return: e.g. [{'response': 'hi!', 'history': [('hello!', 'hi!')]}].
         The keys to be included will be: 'response', 'history'.
     """
+    if len(request_list) == 0:
+        return []
     runtime = time.perf_counter()
     if generation_config is None:
         generation_config = getattr(llm_engine, 'generation_config', VllmGenerationConfig())
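
With the same guard in the vLLM helpers, callers can forward an arbitrary, possibly empty batch without special-casing it. A hedged usage sketch; it assumes an llm_engine/template pair already created with prepare_vllm_engine_template, and that inference_vllm is importable from swift.llm like the other helpers referenced in this diff:

    from swift.llm import inference_vllm

    def answer_all(llm_engine, template, queries):
        request_list = [{'query': q} for q in queries]
        # Safe even when queries is empty: inference_vllm now returns [] immediately.
        resp_list = inference_vllm(llm_engine, template, request_list)
        # Each item follows the documented shape: {'response': ..., 'history': [...]}.
        return [resp['response'] for resp in resp_list]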
