diff --git a/libs/chatchat-server/chatchat/server/api_server/api_schemas.py b/libs/chatchat-server/chatchat/server/api_server/api_schemas.py
index 32acc6d60c..296bb10205 100644
--- a/libs/chatchat-server/chatchat/server/api_server/api_schemas.py
+++ b/libs/chatchat-server/chatchat/server/api_server/api_schemas.py
@@ -115,6 +115,7 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput):
 class OpenAIBaseOutput(BaseModel):
     id: Optional[str] = None
     content: Optional[str] = None
+    reasoning_content: Optional[str] = None
     model: Optional[str] = None
     object: Literal[
         "chat.completion", "chat.completion.chunk"
@@ -150,6 +151,7 @@ def model_dump(self) -> dict:
             {
                 "delta": {
                     "content": self.content,
+                    "reasoning_content": self.reasoning_content,
                     "tool_calls": self.tool_calls,
                 },
                 "role": self.role,
diff --git a/libs/chatchat-server/chatchat/server/chat/deepseek.py b/libs/chatchat-server/chatchat/server/chat/deepseek.py
new file mode 100644
index 0000000000..09a0eea699
--- /dev/null
+++ b/libs/chatchat-server/chatchat/server/chat/deepseek.py
@@ -0,0 +1,152 @@
+import asyncio
+import logging
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, ChatMessage
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGenerationChunk, LLMResult
+from langchain_core.callbacks import CallbackManagerForLLMRun
+
+from langchain_openai import ChatOpenAI
+
+logger = logging.getLogger(__name__)
+
+
+class DeepseekChatOpenAI(ChatOpenAI):
+    """ChatOpenAI variant that surfaces DeepSeek-style reasoning_content deltas."""
+
+    async def _astream(
+        self,
+        messages: Any,
+        stop: Optional[Any] = None,
+        run_manager: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        # Convert LangChain messages to OpenAI-style dicts
+        openai_messages = []
+        for msg in messages:
+            if isinstance(msg, HumanMessage):
+                openai_messages.append({"role": "user", "content": msg.content})
+            elif isinstance(msg, AIMessage):
+                openai_messages.append({"role": "assistant", "content": msg.content})
+            elif isinstance(msg, SystemMessage):
+                openai_messages.append({"role": "system", "content": msg.content})
+            elif isinstance(msg, ChatMessage):
+                openai_messages.append({"role": msg.role, "content": msg.content})
+            else:
+                raise ValueError(f"Unsupported message type: {type(msg)}")
+
+        params = {
+            "model": self.model_name,
+            "messages": openai_messages,
+            **self.model_kwargs,
+            **kwargs,
+            "extra_body": {
+                "enable_enhanced_generation": True,
+                **(kwargs.get("extra_body", {})),
+                **(self.model_kwargs.get("extra_body", {})),
+            },
+        }
+        params = {k: v for k, v in params.items() if v not in (None, {}, [])}
+
+        # Create and process the stream
+        async for chunk in await self.async_client.create(stream=True, **params):
+            delta = chunk.choices[0].delta
+            content = delta.content or ""
+            reasoning = delta.model_extra.get("reasoning_content", "") if delta.model_extra else ""
+            if content:
+                yield ChatGenerationChunk(
+                    message=AIMessageChunk(content=content),
+                    generation_info={"reasoning_content": reasoning},
+                )
+            if reasoning:
+                # Reasoning tokens travel in additional_kwargs so downstream
+                # consumers can tell them apart from answer tokens.
+                yield ChatGenerationChunk(
+                    message=AIMessageChunk(
+                        content="",
+                        additional_kwargs={"reasoning_content": reasoning},
+                    ),
+                    generation_info={"reasoning_content": reasoning},
+                )
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        openai_messages = []
+        for msg in messages:
+            if isinstance(msg, HumanMessage):
+                openai_messages.append({"role": "user", "content": msg.content})
+            elif isinstance(msg, AIMessage):
+                openai_messages.append({"role": "assistant", "content": msg.content})
+            elif isinstance(msg, SystemMessage):
+                openai_messages.append({"role": "system", "content": msg.content})
+            elif isinstance(msg, ChatMessage):
+                openai_messages.append({"role": msg.role, "content": msg.content})
+            else:
+                raise ValueError(f"Unsupported message type: {type(msg)}")
+
+        params = {
+            "model": self.model_name,
+            "messages": openai_messages,
+            **self.model_kwargs,
+            **kwargs,
+            "extra_body": {
+                "enable_enhanced_generation": True,
+                **(kwargs.get("extra_body", {})),
+                **(self.model_kwargs.get("extra_body", {})),
+            },
+        }
+        params = {k: v for k, v in params.items() if v not in (None, {}, [])}
+
+        # Create and process the stream
+        for chunk in self.client.create(stream=True, **params):
+            delta = chunk.choices[0].delta
+            content = delta.content or ""
+            reasoning = delta.model_extra.get("reasoning_content", "") if delta.model_extra else ""
+            if content:
+                yield ChatGenerationChunk(
+                    message=AIMessageChunk(content=content),
+                    generation_info={"reasoning_content": reasoning},
+                )
+            if reasoning:
+                yield ChatGenerationChunk(
+                    message=AIMessageChunk(
+                        content="",
+                        additional_kwargs={"reasoning_content": reasoning},
+                    ),
+                    generation_info={"reasoning_content": reasoning},
+                )
+
+    def invoke(
+        self,
+        messages: Any,
+        stop: Optional[Any] = None,
+        run_manager: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> AIMessage:
+        async def _ainvoke():
+            combined_content = []
+            combined_reasoning = []
+            async for chunk in self._astream(messages, stop, run_manager, **kwargs):
+                if chunk.message.content:
+                    combined_content.append(chunk.message.content)
+                # If reasoning is in additional_kwargs, gather that too
+                if "reasoning_content" in chunk.message.additional_kwargs:
+                    combined_reasoning.append(
+                        chunk.message.additional_kwargs["reasoning_content"]
+                    )
+            return AIMessage(
+                content="".join(combined_content),
+                additional_kwargs={"reasoning_content": "".join(combined_reasoning)} if combined_reasoning else {},
+            )
+
+        # NOTE: asyncio.run() starts a fresh event loop, so this must not be
+        # called from inside an already-running loop.
+        return asyncio.run(_ainvoke())
\ No newline at end of file
diff --git a/libs/chatchat-server/chatchat/server/chat/kb_chat.py b/libs/chatchat-server/chatchat/server/chat/kb_chat.py
index 3aa1ce4e54..e969d0cd60 100644
--- a/libs/chatchat-server/chatchat/server/chat/kb_chat.py
+++ b/libs/chatchat-server/chatchat/server/chat/kb_chat.py
@@ -122,9 +122,6 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
             )
             .model_dump_json()
             return
-        callback = AsyncIteratorCallbackHandler()
-        callbacks = [callback]
-
         # Enable langchain-chatchat to support langfuse
        import os
         langfuse_secret_key = os.environ.get('LANGFUSE_SECRET_KEY')
@@ -142,8 +139,7 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
         llm = get_ChatOpenAI(
             model_name=model,
             temperature=temperature,
-            max_tokens=max_tokens,
-            callbacks=callbacks,
+            max_tokens=max_tokens
         )
         # TODO: 视情况使用 API
         # # 加入reranker
@@ -171,12 +167,6 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
 
         chain = chat_prompt | llm
 
-        # Begin a task that runs in the background.
-        task = asyncio.create_task(wrap_done(
-            chain.ainvoke({"context": context, "question": query}),
-            callback.done),
-        )
-
         if len(source_documents) == 0:  # 没有找到相关文档
             source_documents.append(f"未找到相关文档,该回答为大模型自身能力解答!")
@@ -191,20 +181,38 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
                 docs=source_documents,
             )
             yield ret.model_dump_json()
-
-            async for token in callback.aiter():
-                ret = OpenAIChatOutput(
-                    id=f"chat{uuid.uuid4()}",
-                    object="chat.completion.chunk",
-                    content=token,
-                    role="assistant",
-                    model=model,
-                )
+
+            async for chunk in chain.astream({"context": context, "question": query}):
+                # Reasoning tokens are emitted via reasoning_content,
+                # answer tokens via content.
+                if chunk.additional_kwargs.get("reasoning_content"):
+                    ret = OpenAIChatOutput(
+                        id=f"chat{uuid.uuid4()}",
+                        object="chat.completion.chunk",
+                        reasoning_content=chunk.additional_kwargs["reasoning_content"],
+                        role="assistant",
+                        model=model,
+                    )
+                # Otherwise, treat it as an answer token
+                else:
+                    ret = OpenAIChatOutput(
+                        id=f"chat{uuid.uuid4()}",
+                        object="chat.completion.chunk",
+                        content=chunk.content,
+                        role="assistant",
+                        model=model,
+                    )
                 yield ret.model_dump_json()
         else:
             answer = ""
-            async for token in callback.aiter():
-                answer += token
+            async for chunk in chain.astream({"context": context, "question": query}):
+                if chunk.additional_kwargs.get("reasoning_content"):
+                    answer += chunk.additional_kwargs["reasoning_content"]
+                else:
+                    answer += chunk.content
+
             ret = OpenAIChatOutput(
@@ -213,7 +221,6 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
                 model=model,
             )
             yield ret.model_dump_json()
-            await task
     except asyncio.exceptions.CancelledError:
         logger.warning("streaming progress has been interrupted by user.")
         return
diff --git a/libs/chatchat-server/chatchat/server/utils.py b/libs/chatchat-server/chatchat/server/utils.py
index d58d4a4c0e..87ee610c78 100644
--- a/libs/chatchat-server/chatchat/server/utils.py
+++ b/libs/chatchat-server/chatchat/server/utils.py
@@ -26,6 +26,7 @@
 from langchain.tools import BaseTool
 from langchain_core.embeddings import Embeddings
 from langchain_openai.chat_models import ChatOpenAI
+from chatchat.server.chat.deepseek import DeepseekChatOpenAI
 from langchain_openai.llms import OpenAI
 
 from memoization import cached, CachingAlgorithmFlag
@@ -225,7 +226,7 @@ def get_ChatOpenAI(
     verbose: bool = True,
     local_wrap: bool = False,  # use local wrapped api
     **kwargs: Any,
-) -> ChatOpenAI:
+) -> DeepseekChatOpenAI:
     model_info = get_model_info(model_name)
     params = dict(
         streaming=streaming,
@@ -253,7 +254,7 @@ def get_ChatOpenAI(
             openai_api_key=model_info.get("api_key"),
             openai_proxy=model_info.get("api_proxy"),
         )
-        model = ChatOpenAI(**params)
+        model = DeepseekChatOpenAI(**params)
     except Exception as e:
         logger.exception(f"failed to create ChatOpenAI for model: {model_name}.")
         model = None
@@ -817,7 +818,7 @@ def get_httpx_client(
         default_proxies.update(proxies)
 
     # construct Client
-    kwargs.update(timeout=timeout, proxies=default_proxies)
+    kwargs.update(timeout=timeout)
 
     if use_async:
         return httpx.AsyncClient(**kwargs)
diff --git a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py
index a45d886798..4c44921fa6 100644
--- a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py
+++ b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py
@@ -395,6 +395,7 @@ def on_conv_change():
         chat_box.ai_say("正在思考...")
         text = ""
+        reasoning_text = ""
         started = False
 
         client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
@@ -519,10 +520,24 @@ def on_conv_change():
                         for img in d.tool_output.get("images", []):
                             chat_box.insert_msg(Image(f"{api.base_url}/media/{img}"), pos=-2)
                     else:
-                        text += d.choices[0].delta.content or ""
-                        chat_box.update_msg(
-                            text.replace("\n", "\n\n"), streaming=True, metadata=metadata
-                        )
+                        reasoning_content = getattr(d.choices[0].delta, "reasoning_content", None)
+                        if reasoning_content:
+                            if reasoning_text == "":
+                                chat_box.insert_msg(
+                                    Markdown("...", in_expander=True, title="深度思考", state="running", expanded=True)
+                                )
+                            reasoning_text += reasoning_content
+                            chat_box.update_msg(reasoning_text, streaming=True, state="running")
+                            continue
+                        content = getattr(d.choices[0].delta, "content", None)
+                        if content:
+                            if text == "" and reasoning_text != "":
+                                # The final answer has started streaming; close out the reasoning expander
+                                chat_box.update_msg(reasoning_text, streaming=False, state="complete")
+                                chat_box.insert_msg("")
+                            text += content
+                            chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True)
             chat_box.update_msg(text, streaming=False, metadata=metadata)
     except Exception as e:
         st.error(e.body)
diff --git a/libs/chatchat-server/chatchat/webui_pages/kb_chat.py b/libs/chatchat-server/chatchat/webui_pages/kb_chat.py
index bcadb82778..3925861fd4 100644
--- a/libs/chatchat-server/chatchat/webui_pages/kb_chat.py
+++ b/libs/chatchat-server/chatchat/webui_pages/kb_chat.py
@@ -219,6 +219,7 @@ def on_conv_change():
         ])
 
         text = ""
+        reasoning_text = ""
         first = True
 
         try:
@@ -228,8 +229,23 @@ def on_conv_change():
                     chat_box.update_msg("", streaming=False)
                     first = False
                     continue
-                text += d.choices[0].delta.content or ""
-                chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True)
+                reasoning_content = getattr(d.choices[0].delta, "reasoning_content", None)
+                if reasoning_content:
+                    if reasoning_text == "":
+                        chat_box.insert_msg(
+                            Markdown("...", in_expander=True, title="深度思考", state="running", expanded=True)
+                        )
+                    reasoning_text += reasoning_content
+                    chat_box.update_msg(reasoning_text, streaming=True, state="running")
+                    continue
+                content = getattr(d.choices[0].delta, "content", None)
+                if content:
+                    if text == "" and reasoning_text != "":
+                        # The final answer has started streaming; close out the reasoning expander
+                        chat_box.update_msg(reasoning_text, streaming=False, state="complete")
+                        chat_box.insert_msg("")
+                    text += content
+                    chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True)
             chat_box.update_msg(text, streaming=False)
         # TODO: 搜索未配置API KEY时产生报错
     except Exception as e:
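
Usage note (illustrative, not part of the patch): a minimal sketch of how a client could consume the reasoning stream this diff adds. The server address and model name below are placeholder assumptions; the per-delta field access mirrors the webui code above.

# Hypothetical client-side demo; base_url and model name are assumptions.
import openai

client = openai.Client(base_url="http://127.0.0.1:7861/chat", api_key="NONE")  # assumed local chatchat server
resp = client.chat.completions.create(
    model="deepseek-reasoner",  # placeholder model name
    messages=[{"role": "user", "content": "Which is larger, 9.11 or 9.8?"}],
    stream=True,
)
for d in resp:
    delta = d.choices[0].delta
    # reasoning_content is the extra delta field emitted by the patched OpenAIChatOutput
    reasoning = getattr(delta, "reasoning_content", None)
    if reasoning:
        print(f"[thinking] {reasoning}", end="", flush=True)
    elif delta.content:
        print(delta.content, end="", flush=True)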