Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions swift/infer_engine/infer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ def _parse_stream_data(data: bytes) -> Optional[str]:
data = data.strip()
if len(data) == 0:
return
assert data.startswith('data:'), f'data: {data}'
return data[5:].strip()
if data.startswith('data:'):
return data[5:].strip()
return data

async def infer_async(
self,
Expand All @@ -138,6 +139,15 @@ async def infer_async(
async def _gen_stream() -> AsyncIterator[ChatCompletionStreamResponse]:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=request_data, **self._get_request_kwargs()) as resp:
if resp.status >= 400 or resp.content_type != 'text/event-stream':
data = await resp.text()
try:
resp_obj = json.loads(data)
except json.JSONDecodeError:
raise HTTPError(data)
if resp_obj.get('object') == 'error':
raise HTTPError(resp_obj['message'])
raise HTTPError(data)
async for data in resp.content:
data = self._parse_stream_data(data)
if data == '[DONE]':
Expand Down
6 changes: 3 additions & 3 deletions swift/infer_engine/infer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,8 @@ async def _run_async_iter():
async for item in await async_iter:
queue.put(item)
except Exception as e:
if getattr(self, 'strict', True):
raise
queue.put(e)
else:
finally:
queue.put(None)

try:
Expand All @@ -103,6 +101,8 @@ async def _run_async_iter():
if output is None or isinstance(output, Exception):
prog_bar.update()
self._update_metrics(pre_output, metrics)
if isinstance(output, Exception) and getattr(self, 'strict', True):
raise output
return
pre_output = output
yield output
Expand Down
34 changes: 32 additions & 2 deletions swift/ui/llm_infer/llm_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ class LLMInfer(BaseUI):
'en': 'Port'
},
},
'api_key': {
'label': {
'zh': '接口token',
'en': 'API key'
},
'info': {
'zh': '部署服务使用的API key,聊天时会自动复用',
'en': 'API key used by the deployed service and reused for chat requests'
}
},
'llm_infer': {
'label': {
'zh': 'LLM推理',
Expand Down Expand Up @@ -140,6 +150,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
scale=8)
infer_model_type = gr.Textbox(elem_id='infer_model_type', scale=4)
gr.Textbox(elem_id='port', lines=1, value='8000', scale=4)
gr.Textbox(elem_id='api_key', lines=1, scale=6)
chatbot = gr.Chatbot(elem_id='chatbot', elem_classes='control-height')
with gr.Row(equal_height=True):
prompt = gr.Textbox(elem_id='prompt', lines=1, interactive=True)
Expand Down Expand Up @@ -388,12 +399,16 @@ def send_message(cls, running_task, template_type, prompt: str, image, video, au
infer_request.messages[-1]['content'] = infer_request.messages[-1]['content'] + prompt

_, args = Runtime.parse_info_from_cmdline(running_task)
if 'port' not in args:
raise gr.Error('Please select a valid running deployment first.')
request_config = RequestConfig(
temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
request_config.stream = True
request_config.stop = ['Observation:']
request_config.max_tokens = max_new_tokens
stream_resp_with_history = ''
stream_reasoning_content = ''
stream_response_content = ''
response = ''
i = len(infer_request.messages) - 1
for i in range(len(infer_request.messages) - 1, -1, -1):
Expand All @@ -412,14 +427,29 @@ def send_message(cls, running_task, template_type, prompt: str, image, video, au
if infer_model_type:
model_kwargs = {'model': infer_model_type}
gen_list = InferClient(
port=args['port'], ).infer(
port=args['port'],
api_key=args.get('api_key', 'EMPTY'),
).infer(
infer_requests=[_infer_request], request_config=request_config, **model_kwargs)
if infer_request.messages[-1]['role'] != 'assistant':
infer_request.messages.append({'role': 'assistant', 'content': ''})
for chunk in gen_list[0]:
if chunk is None:
continue
stream_resp_with_history += chunk.choices[0].delta.content if chat else chunk.choices[0].text
if chat:
delta = chunk.choices[0].delta
if delta.reasoning_content:
stream_reasoning_content += delta.reasoning_content
if delta.content:
stream_response_content += delta.content
if stream_reasoning_content and stream_response_content:
stream_resp_with_history = f'<think>\n{stream_reasoning_content}</think>\n{stream_response_content}'
elif stream_reasoning_content:
stream_resp_with_history = f'<think>\n{stream_reasoning_content}'
else:
stream_resp_with_history = stream_response_content
else:
stream_resp_with_history += chunk.choices[0].text
infer_request.messages[-1]['content'] = stream_resp_with_history
chatbot_content = cls._replace_tag_with_media(infer_request)
chatbot_content = cls.parse_text(chatbot_content)
Expand Down
70 changes: 54 additions & 16 deletions swift/ui/llm_infer/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os.path
import psutil
import re
import shlex
import subprocess
import sys
import time
Expand Down Expand Up @@ -122,8 +123,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
def break_log_event(cls, task):
    """Flag the log stream of ``task`` so that it is marked as interrupted.

    Sets ``cls.log_event[log_file]`` to ``True`` for the log file named by
    the ``--log_file`` argument parsed from ``task``.

    Args:
        task: Running-task description string understood by
            ``cls.parse_info_from_cmdline``. Falsy values are ignored.

    Returns:
        None. The call is a no-op when ``task`` is falsy or when no
        ``log_file`` argument can be recovered from it.
    """
    if not task:
        return
    _, all_args = cls.parse_info_from_cmdline(task)
    # The parser may return an empty mapping for command lines it does not
    # recognise, so look the key up defensively instead of indexing.
    log_file = all_args.get('log_file')
    if not log_file:
        return
    cls.log_event[log_file] = True

@classmethod
def update_log(cls):
Expand All @@ -134,7 +138,9 @@ def wait(cls, task):
if not task:
return [None]
_, args = cls.parse_info_from_cmdline(task)
log_file = args['log_file']
log_file = args.get('log_file')
if not log_file:
return [None]
cls.log_event[log_file] = False
offset = 0
latest_data = ''
Expand Down Expand Up @@ -230,29 +236,61 @@ def construct_running_task(proc):
@classmethod
def parse_info_from_cmdline(cls, task):
    """Extract the process id and the ``--`` arguments from a task string.

    The pid is taken from a ``pid:<digits>`` segment located at the start of
    ``task`` or right after a ``/``. The command line is whatever follows the
    first ``/cmd:`` marker (or the whole string when the marker is absent).
    Arguments are read from the text after ``swift <cls.cmd>`` or, failing
    that, after a ``<cls.cmd>.py`` script path.

    Args:
        task: Task description string. Non-string or empty values are
            tolerated.

    Returns:
        A ``(pid, all_args)`` tuple. ``pid`` is the digit string or ``None``
        when no pid segment is present. ``all_args`` maps each option name
        (without the leading ``--``) to its value: several value tokens are
        joined with a single space, an option without a value maps to
        ``'true'``, and ``--key=value`` is accepted as well. ``all_args`` is
        empty when the command is not recognised or cannot be tokenised.
    """
    pid = None
    if not isinstance(task, str) or not task:
        return pid, {}

    pid_match = re.search(r'(?:^|/)pid:(\d+)', task)
    if pid_match:
        pid = pid_match.group(1)

    cmdline = task.split('/cmd:', 1)[1] if '/cmd:' in task else task
    launcher = f'swift {cls.cmd}'
    if launcher in cmdline:
        raw_args = cmdline.split(launcher, 1)[1]
    else:
        # Fall back to direct script invocations such as `python .../deploy.py`.
        script_match = re.search(rf'\S*{re.escape(cls.cmd)}\.py(?=\s|$)', cmdline)
        if script_match is None:
            return pid, {}
        raw_args = cmdline[script_match.end():]

    # NOTE(review): POSIX-mode shlex treats backslashes as escapes, so an
    # unquoted Windows path would lose its separators here - confirm how task
    # strings are built on win32.
    try:
        tokens = shlex.split(raw_args)
    except ValueError:
        # Unbalanced quotes: report the pid but no arguments.
        return pid, {}

    options = []  # (name, [value tokens]) in order of appearance
    for token in tokens:
        if token.startswith('--'):
            name, has_inline, inline = token[2:].partition('=')
            options.append((name, [inline] if has_inline else []))
        elif options:
            options[-1][1].append(token)
        # Tokens that precede the first option are ignored.

    # Later occurrences of the same option override earlier ones.
    all_args = {name: ' '.join(values) if values else 'true' for name, values in options}
    return pid, all_args

@classmethod
def kill_task(cls, task):
if task:
pid, all_args = cls.parse_info_from_cmdline(task)
log_file = all_args['log_file']
log_file = all_args.get('log_file')
if sys.platform == 'win32':
if not pid:
return [cls.refresh_tasks()] + [gr.update(value=None)]
command = ['taskkill', '/f', '/t', '/pid', pid]
else:
command = ['pkill', '-9', '-f', log_file]
if log_file:
command = ['pkill', '-9', '-f', log_file]
elif pid:
command = ['kill', '-9', pid]
else:
return [cls.refresh_tasks()] + [gr.update(value=None)]
try:
result = subprocess.run(command, capture_output=True, text=True)
assert result.returncode == 0
Expand Down