diff --git a/swift/infer_engine/infer_client.py b/swift/infer_engine/infer_client.py index 055ecac0cc..69e51a11b3 100644 --- a/swift/infer_engine/infer_client.py +++ b/swift/infer_engine/infer_client.py @@ -114,8 +114,9 @@ def _parse_stream_data(data: bytes) -> Optional[str]: data = data.strip() if len(data) == 0: return - assert data.startswith('data:'), f'data: {data}' - return data[5:].strip() + if data.startswith('data:'): + return data[5:].strip() + return data async def infer_async( self, @@ -138,6 +139,15 @@ async def infer_async( async def _gen_stream() -> AsyncIterator[ChatCompletionStreamResponse]: async with aiohttp.ClientSession() as session: async with session.post(url, json=request_data, **self._get_request_kwargs()) as resp: + if resp.status >= 400 or resp.content_type != 'text/event-stream': + data = await resp.text() + try: + resp_obj = json.loads(data) + except json.JSONDecodeError: + raise HTTPError(data) + if resp_obj.get('object') == 'error': + raise HTTPError(resp_obj['message']) + raise HTTPError(data) async for data in resp.content: data = self._parse_stream_data(data) if data == '[DONE]': diff --git a/swift/infer_engine/infer_engine.py b/swift/infer_engine/infer_engine.py index b6a2e932eb..795eba8670 100644 --- a/swift/infer_engine/infer_engine.py +++ b/swift/infer_engine/infer_engine.py @@ -84,10 +84,8 @@ async def _run_async_iter(): async for item in await async_iter: queue.put(item) except Exception as e: - if getattr(self, 'strict', True): - raise queue.put(e) - else: + finally: queue.put(None) try: @@ -103,6 +101,8 @@ async def _run_async_iter(): if output is None or isinstance(output, Exception): prog_bar.update() self._update_metrics(pre_output, metrics) + if isinstance(output, Exception) and getattr(self, 'strict', True): + raise output return pre_output = output yield output diff --git a/swift/infer_engine/transformers_engine.py b/swift/infer_engine/transformers_engine.py index ac7e963717..7558b7acfe 100644 --- 
a/swift/infer_engine/transformers_engine.py +++ b/swift/infer_engine/transformers_engine.py @@ -215,6 +215,41 @@ def _update_batched_logprobs(batched_logprobs: List[torch.Tensor], logits_stream for logprobs, new_logprobs in zip(batched_logprobs, new_batched_logprobs): logprobs += new_logprobs + @staticmethod + def _extract_reasoning_content(response: str) -> tuple[Optional[str], str]: + if '<think>' not in response: + return None, response + _, suffix = response.split('<think>', 1) + if '</think>' not in suffix: + return suffix.lstrip('\n'), '' + reasoning_content, content = suffix.split('</think>', 1) + return reasoning_content.lstrip('\n'), content.lstrip('\n') + + @classmethod + def _extract_reasoning_delta(cls, previous_text: str, current_text: str) -> tuple[Optional[str], Optional[str]]: + previous_reasoning, previous_content = cls._extract_reasoning_content(previous_text) + current_reasoning, current_content = cls._extract_reasoning_content(current_text) + + delta_reasoning_content = None + if current_reasoning is not None: + previous_reasoning = previous_reasoning or '' + if current_reasoning.startswith(previous_reasoning): + delta_reasoning_content = current_reasoning[len(previous_reasoning):] + else: + delta_reasoning_content = current_reasoning + if not delta_reasoning_content: + delta_reasoning_content = None + + delta_content = None + if current_content: + if current_content.startswith(previous_content): + delta_content = current_content[len(previous_content):] + else: + delta_content = current_content + if not delta_content: + delta_content = None + return delta_reasoning_content, delta_content + def _infer_stream(self, inputs: Dict[str, Any], *, generation_config: GenerationConfig, adapter_request: Optional[AdapterRequest], request_config: RequestConfig, **kwargs) -> Iterator[List[Optional[ChatCompletionStreamResponse]]]: @@ -251,6 +286,7 @@ def _model_generate(**kwargs): infer_streamers = [InferStreamer(self.template) for _ in range(batch_size)] request_id_list = 
[f'chatcmpl-{random_uuid()}' for _ in range(batch_size)] token_idxs = [0] * batch_size + response_texts = [''] * batch_size raw_batched_generate_ids = None # or torch.Tensor: [batch_size, seq_len] batched_logprobs = [[] for _ in range(batch_size)] @@ -295,17 +331,30 @@ def _model_generate(**kwargs): logprobs = self._get_logprobs(logprobs_list, generate_ids[token_idxs[i]:], request_config.top_logprobs) token_idxs[i] = len(generate_ids) + previous_text = response_texts[i] + response_texts[i] = previous_text + (delta_text or '') + delta_reasoning_content, delta_content = self._extract_reasoning_delta(previous_text, response_texts[i]) + if not delta_content and not delta_reasoning_content and not is_finished[i]: + res.append(None) + continue + usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids)) toolcall = None if is_finished[i]: - toolcall = self._get_toolcall(self.template.decode(generate_ids)) + response = self.template.decode(generate_ids) + _, content = self._extract_reasoning_content(response) + toolcall = self._get_toolcall(content or response) finish_reason = self._get_finish_reason(generation_config.max_new_tokens, usage_info.completion_tokens, is_finished[i]) choices = [ ChatCompletionResponseStreamChoice( index=0, - delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall), + delta=DeltaMessage( + role='assistant', + content=delta_content, + reasoning_content=delta_reasoning_content, + tool_calls=toolcall), finish_reason=finish_reason, logprobs=logprobs) ] @@ -423,13 +472,18 @@ def _infer_full(self, inputs: Dict[str, Any], *, generation_config: GenerationCo logprobs = self._get_logprobs(logprobs_list, generate_ids, request_config.top_logprobs) usage_info = self._update_usage_info(usage_info, len(generate_ids)) response = self.template.decode(generate_ids, template_inputs=template_inputs[i]) + reasoning_content, content = self._extract_reasoning_content(response) finish_reason = 
self._get_finish_reason(generation_config.max_new_tokens, len(generate_ids), True) - toolcall = self._get_toolcall(response) + toolcall = self._get_toolcall(content or response) token_ids = generate_ids if request_config.return_details else None choices.append( ChatCompletionResponseChoice( index=j, - message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), + message=ChatMessage( + role='assistant', + content=content, + reasoning_content=reasoning_content, + tool_calls=toolcall), finish_reason=finish_reason, logprobs=logprobs, token_ids=token_ids)) diff --git a/swift/pipelines/infer/deploy.py b/swift/pipelines/infer/deploy.py index b81b6ca829..8de46578b3 100644 --- a/swift/pipelines/infer/deploy.py +++ b/swift/pipelines/infer/deploy.py @@ -146,9 +146,16 @@ def _post_process(self, request_info, response, return_cmpl_response: bool = Fal is_finished = all(response.choices[i].finish_reason for i in range(len(response.choices))) if 'stream' in response.__class__.__name__.lower(): - request_info['response'] += response.choices[0].delta.content + delta = response.choices[0].delta + if delta.content: + request_info['response'] += delta.content + if getattr(delta, 'reasoning_content', None): + request_info.setdefault('reasoning_content', '') + request_info['reasoning_content'] += delta.reasoning_content else: request_info['response'] = response.choices[0].message.content + if getattr(response.choices[0].message, 'reasoning_content', None): + request_info['reasoning_content'] = response.choices[0].message.reasoning_content if return_cmpl_response: response = response.to_cmpl_response() if is_finished: diff --git a/swift/ui/llm_infer/llm_infer.py b/swift/ui/llm_infer/llm_infer.py index f729bef133..09284c5cc1 100644 --- a/swift/ui/llm_infer/llm_infer.py +++ b/swift/ui/llm_infer/llm_infer.py @@ -44,6 +44,16 @@ class LLMInfer(BaseUI): 'en': 'Port' }, }, + 'api_key': { + 'label': { + 'zh': '接口token', + 'en': 'API key' + }, + 'info': { + 'zh': '部署服务使用的API 
key,聊天时会自动复用', + 'en': 'API key used by the deployed service and reused for chat requests' + } + }, 'llm_infer': { 'label': { 'zh': 'LLM推理', @@ -140,6 +150,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): scale=8) infer_model_type = gr.Textbox(elem_id='infer_model_type', scale=4) gr.Textbox(elem_id='port', lines=1, value='8000', scale=4) + gr.Textbox(elem_id='api_key', lines=1, scale=6) chatbot = gr.Chatbot(elem_id='chatbot', elem_classes='control-height') with gr.Row(equal_height=True): prompt = gr.Textbox(elem_id='prompt', lines=1, interactive=True) @@ -388,12 +399,16 @@ def send_message(cls, running_task, template_type, prompt: str, image, video, au infer_request.messages[-1]['content'] = infer_request.messages[-1]['content'] + prompt _, args = Runtime.parse_info_from_cmdline(running_task) + if 'port' not in args: + raise gr.Error('Please select a valid running deployment first.') request_config = RequestConfig( temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty) request_config.stream = True request_config.stop = ['Observation:'] request_config.max_tokens = max_new_tokens stream_resp_with_history = '' + stream_reasoning_content = '' + stream_response_content = '' response = '' i = len(infer_request.messages) - 1 for i in range(len(infer_request.messages) - 1, -1, -1): @@ -412,14 +427,29 @@ def send_message(cls, running_task, template_type, prompt: str, image, video, au if infer_model_type: model_kwargs = {'model': infer_model_type} gen_list = InferClient( - port=args['port'], ).infer( + port=args['port'], + api_key=args.get('api_key', 'EMPTY'), + ).infer( infer_requests=[_infer_request], request_config=request_config, **model_kwargs) if infer_request.messages[-1]['role'] != 'assistant': infer_request.messages.append({'role': 'assistant', 'content': ''}) for chunk in gen_list[0]: if chunk is None: continue - stream_resp_with_history += chunk.choices[0].delta.content if chat else chunk.choices[0].text + if chat: + delta = 
chunk.choices[0].delta + if delta.reasoning_content: + stream_reasoning_content += delta.reasoning_content + if delta.content: + stream_response_content += delta.content + if stream_reasoning_content and stream_response_content: + stream_resp_with_history = f'<think>\n{stream_reasoning_content}\n</think>\n{stream_response_content}' + elif stream_reasoning_content: + stream_resp_with_history = f'<think>\n{stream_reasoning_content}' + else: + stream_resp_with_history = stream_response_content + else: + stream_resp_with_history += chunk.choices[0].text infer_request.messages[-1]['content'] = stream_resp_with_history chatbot_content = cls._replace_tag_with_media(infer_request) chatbot_content = cls.parse_text(chatbot_content) diff --git a/swift/ui/llm_infer/runtime.py b/swift/ui/llm_infer/runtime.py index 6c190ba2fd..525844710b 100644 --- a/swift/ui/llm_infer/runtime.py +++ b/swift/ui/llm_infer/runtime.py @@ -4,6 +4,7 @@ import os.path import psutil import re +import shlex import subprocess import sys import time @@ -122,8 +123,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): def break_log_event(cls, task): if not task: return - pid, all_args = cls.parse_info_from_cmdline(task) - cls.log_event[all_args['log_file']] = True + _, all_args = cls.parse_info_from_cmdline(task) + log_file = all_args.get('log_file') + if not log_file: + return + cls.log_event[log_file] = True @classmethod def update_log(cls): @@ -134,7 +138,9 @@ def wait(cls, task): if not task: return [None] _, args = cls.parse_info_from_cmdline(task) - log_file = args['log_file'] + log_file = args.get('log_file') + if not log_file: + return [None] cls.log_event[log_file] = False offset = 0 latest_data = '' @@ -230,29 +236,61 @@ def construct_running_task(proc): @classmethod def parse_info_from_cmdline(cls, task): pid = None - for i in range(3): - slash = task.find('/') - if i == 0: - pid = task[:slash].split(':')[1] - task = task[slash + 1:] - args = task.split(f'swift {cls.cmd}')[1] - args = [arg.strip() for arg in 
args.split('--') if arg.strip()] + if not isinstance(task, str) or not task: + return pid, {} + + pid_match = re.search(r'(?:^|/)pid:(\d+)', task) + if pid_match: + pid = pid_match.group(1) + + cmdline = task.split('/cmd:', 1)[1] if '/cmd:' in task else task + args = None + if f'swift {cls.cmd}' in cmdline: + args = cmdline.split(f'swift {cls.cmd}', 1)[1] + else: + deploy_match = re.search(rf'\S*{re.escape(cls.cmd)}\.py(?=\s|$)', cmdline) + if deploy_match: + args = cmdline[deploy_match.end():] + if args is None: + return pid, {} + + try: + tokens = shlex.split(args) + except ValueError: + return pid, {} + all_args = {} - for i in range(len(args)): - space = args[i].find(' ') - splits = args[i][:space], args[i][space + 1:] - all_args[splits[0]] = splits[1] + i = 0 + while i < len(tokens): + token = tokens[i] + if not token.startswith('--'): + i += 1 + continue + key = token[2:] + i += 1 + values = [] + while i < len(tokens) and not tokens[i].startswith('--'): + values.append(tokens[i]) + i += 1 + all_args[key] = ' '.join(values) if values else 'true' return pid, all_args @classmethod def kill_task(cls, task): if task: pid, all_args = cls.parse_info_from_cmdline(task) - log_file = all_args['log_file'] + log_file = all_args.get('log_file') if sys.platform == 'win32': + if not pid: + return [cls.refresh_tasks()] + [gr.update(value=None)] command = ['taskkill', '/f', '/t', '/pid', pid] else: - command = ['pkill', '-9', '-f', log_file] + if log_file: + command = ['pkill', '-9', '-f', log_file] + elif pid: + command = ['kill', '-9', pid] + else: + return [cls.refresh_tasks()] + [gr.update(value=None)] try: result = subprocess.run(command, capture_output=True, text=True) assert result.returncode == 0