diff --git a/swift/infer_engine/infer_client.py b/swift/infer_engine/infer_client.py
index 055ecac0cc..69e51a11b3 100644
--- a/swift/infer_engine/infer_client.py
+++ b/swift/infer_engine/infer_client.py
@@ -114,8 +114,9 @@ def _parse_stream_data(data: bytes) -> Optional[str]:
data = data.strip()
if len(data) == 0:
return
- assert data.startswith('data:'), f'data: {data}'
- return data[5:].strip()
+ if data.startswith('data:'):
+ return data[5:].strip()
+ return data
async def infer_async(
self,
@@ -138,6 +139,15 @@ async def infer_async(
async def _gen_stream() -> AsyncIterator[ChatCompletionStreamResponse]:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=request_data, **self._get_request_kwargs()) as resp:
+ if resp.status >= 400 or resp.content_type != 'text/event-stream':
+ data = await resp.text()
+ try:
+ resp_obj = json.loads(data)
+ except json.JSONDecodeError:
+ raise HTTPError(data)
+ if resp_obj.get('object') == 'error':
+ raise HTTPError(resp_obj['message'])
+ raise HTTPError(data)
async for data in resp.content:
data = self._parse_stream_data(data)
if data == '[DONE]':
diff --git a/swift/infer_engine/infer_engine.py b/swift/infer_engine/infer_engine.py
index b6a2e932eb..795eba8670 100644
--- a/swift/infer_engine/infer_engine.py
+++ b/swift/infer_engine/infer_engine.py
@@ -84,10 +84,8 @@ async def _run_async_iter():
async for item in await async_iter:
queue.put(item)
except Exception as e:
- if getattr(self, 'strict', True):
- raise
queue.put(e)
- else:
+ finally:
queue.put(None)
try:
@@ -103,6 +101,8 @@ async def _run_async_iter():
if output is None or isinstance(output, Exception):
prog_bar.update()
self._update_metrics(pre_output, metrics)
+ if isinstance(output, Exception) and getattr(self, 'strict', True):
+ raise output
return
pre_output = output
yield output
diff --git a/swift/infer_engine/transformers_engine.py b/swift/infer_engine/transformers_engine.py
index ac7e963717..7558b7acfe 100644
--- a/swift/infer_engine/transformers_engine.py
+++ b/swift/infer_engine/transformers_engine.py
@@ -215,6 +215,41 @@ def _update_batched_logprobs(batched_logprobs: List[torch.Tensor], logits_stream
for logprobs, new_logprobs in zip(batched_logprobs, new_batched_logprobs):
logprobs += new_logprobs
+ @staticmethod
+ def _extract_reasoning_content(response: str) -> tuple[Optional[str], str]:
+ if '<think>' not in response:
+ return None, response
+ _, suffix = response.split('<think>', 1)
+ if '</think>' not in suffix:
+ return suffix.lstrip('\n'), ''
+ reasoning_content, content = suffix.split('</think>', 1)
+ return reasoning_content.lstrip('\n'), content.lstrip('\n')
+
+ @classmethod
+ def _extract_reasoning_delta(cls, previous_text: str, current_text: str) -> tuple[Optional[str], Optional[str]]:
+ previous_reasoning, previous_content = cls._extract_reasoning_content(previous_text)
+ current_reasoning, current_content = cls._extract_reasoning_content(current_text)
+
+ delta_reasoning_content = None
+ if current_reasoning is not None:
+ previous_reasoning = previous_reasoning or ''
+ if current_reasoning.startswith(previous_reasoning):
+ delta_reasoning_content = current_reasoning[len(previous_reasoning):]
+ else:
+ delta_reasoning_content = current_reasoning
+ if not delta_reasoning_content:
+ delta_reasoning_content = None
+
+ delta_content = None
+ if current_content:
+ if current_content.startswith(previous_content):
+ delta_content = current_content[len(previous_content):]
+ else:
+ delta_content = current_content
+ if not delta_content:
+ delta_content = None
+ return delta_reasoning_content, delta_content
+
def _infer_stream(self, inputs: Dict[str, Any], *, generation_config: GenerationConfig,
adapter_request: Optional[AdapterRequest], request_config: RequestConfig,
**kwargs) -> Iterator[List[Optional[ChatCompletionStreamResponse]]]:
@@ -251,6 +286,7 @@ def _model_generate(**kwargs):
infer_streamers = [InferStreamer(self.template) for _ in range(batch_size)]
request_id_list = [f'chatcmpl-{random_uuid()}' for _ in range(batch_size)]
token_idxs = [0] * batch_size
+ response_texts = [''] * batch_size
raw_batched_generate_ids = None # or torch.Tensor: [batch_size, seq_len]
batched_logprobs = [[] for _ in range(batch_size)]
@@ -295,17 +331,30 @@ def _model_generate(**kwargs):
logprobs = self._get_logprobs(logprobs_list, generate_ids[token_idxs[i]:], request_config.top_logprobs)
token_idxs[i] = len(generate_ids)
+ previous_text = response_texts[i]
+ response_texts[i] = previous_text + (delta_text or '')
+ delta_reasoning_content, delta_content = self._extract_reasoning_delta(previous_text, response_texts[i])
+ if not delta_content and not delta_reasoning_content and not is_finished[i]:
+ res.append(None)
+ continue
+
usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
toolcall = None
if is_finished[i]:
- toolcall = self._get_toolcall(self.template.decode(generate_ids))
+ response = self.template.decode(generate_ids)
+ _, content = self._extract_reasoning_content(response)
+ toolcall = self._get_toolcall(content or response)
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, usage_info.completion_tokens,
is_finished[i])
choices = [
ChatCompletionResponseStreamChoice(
index=0,
- delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall),
+ delta=DeltaMessage(
+ role='assistant',
+ content=delta_content,
+ reasoning_content=delta_reasoning_content,
+ tool_calls=toolcall),
finish_reason=finish_reason,
logprobs=logprobs)
]
@@ -423,13 +472,18 @@ def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationCo
logprobs = self._get_logprobs(logprobs_list, generate_ids, request_config.top_logprobs)
usage_info = self._update_usage_info(usage_info, len(generate_ids))
response = self.template.decode(generate_ids, template_inputs=template_inputs[i])
+ reasoning_content, content = self._extract_reasoning_content(response)
finish_reason = self._get_finish_reason(generation_config.max_new_tokens, len(generate_ids), True)
- toolcall = self._get_toolcall(response)
+ toolcall = self._get_toolcall(content or response)
token_ids = generate_ids if request_config.return_details else None
choices.append(
ChatCompletionResponseChoice(
index=j,
- message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
+ message=ChatMessage(
+ role='assistant',
+ content=content,
+ reasoning_content=reasoning_content,
+ tool_calls=toolcall),
finish_reason=finish_reason,
logprobs=logprobs,
token_ids=token_ids))
diff --git a/swift/pipelines/infer/deploy.py b/swift/pipelines/infer/deploy.py
index b81b6ca829..8de46578b3 100644
--- a/swift/pipelines/infer/deploy.py
+++ b/swift/pipelines/infer/deploy.py
@@ -146,9 +146,16 @@ def _post_process(self, request_info, response, return_cmpl_response: bool = Fal
is_finished = all(response.choices[i].finish_reason for i in range(len(response.choices)))
if 'stream' in response.__class__.__name__.lower():
- request_info['response'] += response.choices[0].delta.content
+ delta = response.choices[0].delta
+ if delta.content:
+ request_info['response'] += delta.content
+ if getattr(delta, 'reasoning_content', None):
+ request_info.setdefault('reasoning_content', '')
+ request_info['reasoning_content'] += delta.reasoning_content
else:
request_info['response'] = response.choices[0].message.content
+ if getattr(response.choices[0].message, 'reasoning_content', None):
+ request_info['reasoning_content'] = response.choices[0].message.reasoning_content
if return_cmpl_response:
response = response.to_cmpl_response()
if is_finished:
diff --git a/swift/ui/llm_infer/llm_infer.py b/swift/ui/llm_infer/llm_infer.py
index f729bef133..09284c5cc1 100644
--- a/swift/ui/llm_infer/llm_infer.py
+++ b/swift/ui/llm_infer/llm_infer.py
@@ -44,6 +44,16 @@ class LLMInfer(BaseUI):
'en': 'Port'
},
},
+ 'api_key': {
+ 'label': {
+ 'zh': '接口token',
+ 'en': 'API key'
+ },
+ 'info': {
+ 'zh': '部署服务使用的API key,聊天时会自动复用',
+ 'en': 'API key used by the deployed service and reused for chat requests'
+ }
+ },
'llm_infer': {
'label': {
'zh': 'LLM推理',
@@ -140,6 +150,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
scale=8)
infer_model_type = gr.Textbox(elem_id='infer_model_type', scale=4)
gr.Textbox(elem_id='port', lines=1, value='8000', scale=4)
+ gr.Textbox(elem_id='api_key', lines=1, scale=6)
chatbot = gr.Chatbot(elem_id='chatbot', elem_classes='control-height')
with gr.Row(equal_height=True):
prompt = gr.Textbox(elem_id='prompt', lines=1, interactive=True)
@@ -388,12 +399,16 @@ def send_message(cls, running_task, template_type, prompt: str, image, video, au
infer_request.messages[-1]['content'] = infer_request.messages[-1]['content'] + prompt
_, args = Runtime.parse_info_from_cmdline(running_task)
+ if 'port' not in args:
+ raise gr.Error('Please select a valid running deployment first.')
request_config = RequestConfig(
temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
request_config.stream = True
request_config.stop = ['Observation:']
request_config.max_tokens = max_new_tokens
stream_resp_with_history = ''
+ stream_reasoning_content = ''
+ stream_response_content = ''
response = ''
i = len(infer_request.messages) - 1
for i in range(len(infer_request.messages) - 1, -1, -1):
@@ -412,14 +427,29 @@ def send_message(cls, running_task, template_type, prompt: str, image, video, au
if infer_model_type:
model_kwargs = {'model': infer_model_type}
gen_list = InferClient(
- port=args['port'], ).infer(
+ port=args['port'],
+ api_key=args.get('api_key', 'EMPTY'),
+ ).infer(
infer_requests=[_infer_request], request_config=request_config, **model_kwargs)
if infer_request.messages[-1]['role'] != 'assistant':
infer_request.messages.append({'role': 'assistant', 'content': ''})
for chunk in gen_list[0]:
if chunk is None:
continue
- stream_resp_with_history += chunk.choices[0].delta.content if chat else chunk.choices[0].text
+ if chat:
+ delta = chunk.choices[0].delta
+ if delta.reasoning_content:
+ stream_reasoning_content += delta.reasoning_content
+ if delta.content:
+ stream_response_content += delta.content
+ if stream_reasoning_content and stream_response_content:
+ stream_resp_with_history = f'<think>\n{stream_reasoning_content}\n</think>\n{stream_response_content}'
+ elif stream_reasoning_content:
+ stream_resp_with_history = f'<think>\n{stream_reasoning_content}'
+ else:
+ stream_resp_with_history = stream_response_content
+ else:
+ stream_resp_with_history += chunk.choices[0].text
infer_request.messages[-1]['content'] = stream_resp_with_history
chatbot_content = cls._replace_tag_with_media(infer_request)
chatbot_content = cls.parse_text(chatbot_content)
diff --git a/swift/ui/llm_infer/runtime.py b/swift/ui/llm_infer/runtime.py
index 6c190ba2fd..525844710b 100644
--- a/swift/ui/llm_infer/runtime.py
+++ b/swift/ui/llm_infer/runtime.py
@@ -4,6 +4,7 @@
import os.path
import psutil
import re
+import shlex
import subprocess
import sys
import time
@@ -122,8 +123,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
def break_log_event(cls, task):
if not task:
return
- pid, all_args = cls.parse_info_from_cmdline(task)
- cls.log_event[all_args['log_file']] = True
+ _, all_args = cls.parse_info_from_cmdline(task)
+ log_file = all_args.get('log_file')
+ if not log_file:
+ return
+ cls.log_event[log_file] = True
@classmethod
def update_log(cls):
@@ -134,7 +138,9 @@ def wait(cls, task):
if not task:
return [None]
_, args = cls.parse_info_from_cmdline(task)
- log_file = args['log_file']
+ log_file = args.get('log_file')
+ if not log_file:
+ return [None]
cls.log_event[log_file] = False
offset = 0
latest_data = ''
@@ -230,29 +236,61 @@ def construct_running_task(proc):
@classmethod
def parse_info_from_cmdline(cls, task):
pid = None
- for i in range(3):
- slash = task.find('/')
- if i == 0:
- pid = task[:slash].split(':')[1]
- task = task[slash + 1:]
- args = task.split(f'swift {cls.cmd}')[1]
- args = [arg.strip() for arg in args.split('--') if arg.strip()]
+ if not isinstance(task, str) or not task:
+ return pid, {}
+
+ pid_match = re.search(r'(?:^|/)pid:(\d+)', task)
+ if pid_match:
+ pid = pid_match.group(1)
+
+ cmdline = task.split('/cmd:', 1)[1] if '/cmd:' in task else task
+ args = None
+ if f'swift {cls.cmd}' in cmdline:
+ args = cmdline.split(f'swift {cls.cmd}', 1)[1]
+ else:
+ deploy_match = re.search(rf'\S*{re.escape(cls.cmd)}\.py(?=\s|$)', cmdline)
+ if deploy_match:
+ args = cmdline[deploy_match.end():]
+ if args is None:
+ return pid, {}
+
+ try:
+ tokens = shlex.split(args)
+ except ValueError:
+ return pid, {}
+
all_args = {}
- for i in range(len(args)):
- space = args[i].find(' ')
- splits = args[i][:space], args[i][space + 1:]
- all_args[splits[0]] = splits[1]
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+ if not token.startswith('--'):
+ i += 1
+ continue
+ key = token[2:]
+ i += 1
+ values = []
+ while i < len(tokens) and not tokens[i].startswith('--'):
+ values.append(tokens[i])
+ i += 1
+ all_args[key] = ' '.join(values) if values else 'true'
return pid, all_args
@classmethod
def kill_task(cls, task):
if task:
pid, all_args = cls.parse_info_from_cmdline(task)
- log_file = all_args['log_file']
+ log_file = all_args.get('log_file')
if sys.platform == 'win32':
+ if not pid:
+ return [cls.refresh_tasks()] + [gr.update(value=None)]
command = ['taskkill', '/f', '/t', '/pid', pid]
else:
- command = ['pkill', '-9', '-f', log_file]
+ if log_file:
+ command = ['pkill', '-9', '-f', log_file]
+ elif pid:
+ command = ['kill', '-9', pid]
+ else:
+ return [cls.refresh_tasks()] + [gr.update(value=None)]
try:
result = subprocess.run(command, capture_output=True, text=True)
assert result.returncode == 0