
Commit c909cb7

feat: add chart reasoning content
1 parent e69caad commit c909cb7

12 files changed: +591 -184 lines changed
Lines changed: 22 additions & 40 deletions

@@ -1,13 +1,14 @@
 import json
-from pydantic import BaseModel
-from typing import Optional, Dict, Any, Type
 from abc import ABC, abstractmethod
-from langchain_core.language_models import BaseLLM as LangchainBaseLLM
-from langchain_openai import ChatOpenAI
+from typing import Optional, Dict, Any, Type
+
+from langchain.chat_models.base import BaseChatModel
+from pydantic import BaseModel
 from sqlmodel import Session, select
 
-from common.core.db import engine
+from apps.ai_model.openai.llm import BaseChatOpenAI
 from apps.system.models.system_model import AiModelDetail
+from common.core.db import engine
 
 
 # from langchain_community.llms import Tongyi, VLLM
@@ -20,76 +21,57 @@ class LLMConfig(BaseModel):
     api_key: Optional[str] = None
     api_base_url: Optional[str] = None
     additional_params: Dict[str, Any] = {}
-
+
 
 class BaseLLM(ABC):
     """Abstract base class for large language models"""
-
+
     def __init__(self, config: LLMConfig):
         self.config = config
         self._llm = self._init_llm()
-
+
     @abstractmethod
-    def _init_llm(self) -> LangchainBaseLLM:
+    def _init_llm(self) -> BaseChatModel:
         """Initialize specific large language model instance"""
         pass
-
+
     @property
-    def llm(self) -> LangchainBaseLLM:
+    def llm(self) -> BaseChatModel:
         """Return the langchain LLM instance"""
         return self._llm
 
+
 class OpenAILLM(BaseLLM):
-    def _init_llm(self) -> LangchainBaseLLM:
-        return ChatOpenAI(
+    def _init_llm(self) -> BaseChatModel:
+        return BaseChatOpenAI(
             model=self.config.model_name,
             api_key=self.config.api_key,
             base_url=self.config.api_base_url,
             stream_usage=True,
-            **self.config.additional_params
+            **self.config.additional_params,
+            extra_body={"enable_thinking": True},
         )
-
-    def generate(self, prompt: str) -> str:
-        return self.llm.invoke(prompt)
 
-""" class TongyiLLM(BaseLLM):
-    def _init_llm(self) -> LangchainBaseLLM:
-        return Tongyi(
-            model_name=self.config.model_name,
-            dashscope_api_key=self.config.api_key,
-            **self.config.additional_params
-        )
-
     def generate(self, prompt: str) -> str:
         return self.llm.invoke(prompt)
 
-class VLLMLLM(BaseLLM):
-    def _init_llm(self) -> LangchainBaseLLM:
-        return VLLM(
-            model=self.config.model_name,
-            **self.config.additional_params
-        )
-
-    def generate(self, prompt: str) -> str:
-        return self.llm.invoke(prompt) """
-
 
 class LLMFactory:
     """Large Language Model Factory Class"""
-
+
     _llm_types: Dict[str, Type[BaseLLM]] = {
         "openai": OpenAILLM,
         "tongyi": OpenAILLM,
        "vllm": OpenAILLM
     }
-
+
     @classmethod
     def create_llm(cls, config: LLMConfig) -> BaseLLM:
         llm_class = cls._llm_types.get(config.model_type)
         if not llm_class:
             raise ValueError(f"Unsupported LLM type: {config.model_type}")
         return llm_class(config)
-
+
     @classmethod
     def register_llm(cls, model_type: str, llm_class: Type[BaseLLM]):
         """Register new model type"""
@@ -107,6 +89,7 @@ def register_llm(cls, model_type: str, llm_class: Type[BaseLLM]):
         )
     return config """
 
+
 def get_default_config() -> LLMConfig:
     with Session(engine) as session:
         db_model = session.exec(
@@ -130,6 +113,5 @@ def get_default_config() -> LLMConfig:
         model_name=db_model.base_model,
         api_key=db_model.api_key,
         api_base_url=db_model.api_domain,
-        additional_params=additional_params
+        additional_params=additional_params,
     )
-
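A minimal sketch of how the refactored factory is meant to be driven, assuming the module path (not shown in this diff) and using placeholder credentials and model names. All three registered types ("openai", "tongyi", "vllm") now resolve to OpenAILLM, i.e. any OpenAI-compatible endpoint, and extra_body={"enable_thinking": True} is forwarded so Qwen-style backends emit a reasoning stream:

```python
# Hypothetical usage; the factory module path is assumed, and the
# model name, key, and base URL below are placeholders.
from apps.ai_model.llm import LLMConfig, LLMFactory  # assumed path

config = LLMConfig(
    model_type="openai",                    # "tongyi"/"vllm" map here too
    model_name="qwen-plus",                 # placeholder model
    api_key="sk-placeholder",
    api_base_url="https://example.com/v1",  # placeholder endpoint
    additional_params={"temperature": 0.1},
)

llm = LLMFactory.create_llm(config)
# Note: generate() returns the invoke() result (an AIMessage),
# despite the -> str annotation in the diff.
result = llm.generate("How did sales trend last quarter?")
print(result)
```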
Lines changed: 167 additions & 1 deletion

@@ -1 +1,167 @@
-# todo
+from typing import Dict, Optional, Any, Iterator, cast, Mapping
+
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \
+    SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.runnables import RunnableConfig, ensure_config
+from langchain_openai import ChatOpenAI
+from langchain_openai.chat_models.base import _create_usage_metadata
+
+
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    id_ = _dict.get("id")
+    role = cast(str, _dict.get("role"))
+    content = cast(str, _dict.get("content") or "")
+    additional_kwargs: dict = {}
+    if 'reasoning_content' in _dict:
+        additional_kwargs['reasoning_content'] = _dict.get('reasoning_content')
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    tool_call_chunks = []
+    if raw_tool_calls := _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = raw_tool_calls
+        try:
+            tool_call_chunks = [
+                tool_call_chunk(
+                    name=rtc["function"].get("name"),
+                    args=rtc["function"].get("arguments"),
+                    id=rtc.get("id"),
+                    index=rtc["index"],
+                )
+                for rtc in raw_tool_calls
+            ]
+        except KeyError:
+            pass
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content, id=id_)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(
+            content=content,
+            additional_kwargs=additional_kwargs,
+            id=id_,
+            tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+        )
+    elif role in ("system", "developer") or default_class == SystemMessageChunk:
+        if role == "developer":
+            additional_kwargs = {"__openai_role__": "developer"}
+        else:
+            additional_kwargs = {}
+        return SystemMessageChunk(
+            content=content, id=id_, additional_kwargs=additional_kwargs
+        )
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
+    elif role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(
+            content=content, tool_call_id=_dict["tool_call_id"], id=id_
+        )
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role, id=id_)
+    else:
+        return default_class(content=content, id=id_)
+
+
+class BaseChatOpenAI(ChatOpenAI):
+    usage_metadata: dict = {}
+
+    # custom_get_token_ids = custom_get_token_ids
+
+    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
+        return self.usage_metadata
+
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        kwargs['stream_usage'] = True
+        for chunk in super()._stream(*args, **kwargs):
+            if chunk.message.usage_metadata is not None:
+                self.usage_metadata = chunk.message.usage_metadata
+            yield chunk
+
+    def _convert_chunk_to_generation_chunk(
+        self,
+        chunk: dict,
+        default_chunk_class: type,
+        base_generation_info: Optional[dict],
+    ) -> Optional[ChatGenerationChunk]:
+        if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+            return None
+        token_usage = chunk.get("usage")
+        choices = (
+            chunk.get("choices", [])
+            # from beta.chat.completions.stream
+            or chunk.get("chunk", {}).get("choices", [])
+        )
+
+        usage_metadata: Optional[UsageMetadata] = (
+            _create_usage_metadata(token_usage) if token_usage and token_usage.get("prompt_tokens") else None
+        )
+        if len(choices) == 0:
+            # logprobs is implicitly None
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
+            )
+            return generation_chunk
+
+        choice = choices[0]
+        if choice["delta"] is None:
+            return None
+
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {**base_generation_info} if base_generation_info else {}
+
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+        if model_name := chunk.get("model"):
+            generation_info["model_name"] = model_name
+        if system_fingerprint := chunk.get("system_fingerprint"):
+            generation_info["system_fingerprint"] = system_fingerprint
+
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+
+        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+            message_chunk.usage_metadata = usage_metadata
+
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+        return generation_chunk
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> BaseMessage:
+        config = ensure_config(config)
+        chat_result = cast(
+            "ChatGeneration",
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
+                **kwargs,
+            ).generations[0][0],
+
+        ).message
+
+        self.usage_metadata = chat_result.response_metadata[
+            'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
+        return chat_result
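A minimal sketch of how the reasoning stream surfaces to callers, assuming an OpenAI-compatible backend that actually emits `reasoning_content` deltas (Qwen/DeepSeek-style); the endpoint, key, and model below are placeholders:

```python
# Sketch: reading the reasoning stream exposed by BaseChatOpenAI.
# Endpoint, key, and model name are placeholders; reasoning_content
# only appears if the backend emits it in its streamed deltas.
from apps.ai_model.openai.llm import BaseChatOpenAI  # assumed path

llm = BaseChatOpenAI(
    model="qwen-plus",                      # placeholder
    api_key="sk-placeholder",
    base_url="https://example.com/v1",      # placeholder
    extra_body={"enable_thinking": True},
)

reasoning, answer = [], []
for chunk in llm.stream("Why did Q3 revenue dip?"):
    # _convert_delta_to_message_chunk copies reasoning deltas into
    # additional_kwargs, so they ride alongside the normal content.
    if rc := chunk.additional_kwargs.get("reasoning_content"):
        reasoning.append(rc)
    if chunk.content:
        answer.append(chunk.content)

print("".join(reasoning))             # the model's chain of thought
print("".join(answer))                # the final answer
print(llm.get_last_generation_info())  # usage from the last chunk
```

Note that `usage_metadata` is mutated on the model instance itself, so sharing one `BaseChatOpenAI` object across concurrent streams will interleave the recorded usage.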

backend/apps/chat/api/chat.py

Lines changed: 30 additions & 5 deletions

@@ -4,7 +4,7 @@
 from fastapi.responses import StreamingResponse
 
 from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
-    delete_chat
+    delete_chat, get_chat_chart_data, get_chat_predict_data
 from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion
 from apps.chat.task.llm import LLMService, run_task, run_analysis_or_predict_task, run_recommend_questions_task
 from common.core.deps import SessionDep, CurrentUser
@@ -28,6 +28,28 @@ async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int
     )
 
 
+@router.get("/record/get/{chart_record_id}/data")
+async def chat_record_data(session: SessionDep, chart_record_id: int):
+    try:
+        return get_chat_chart_data(chart_record_id=chart_record_id, session=session)
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=str(e)
+        )
+
+
+@router.get("/record/get/{chart_record_id}/predict_data")
+async def chat_predict_data(session: SessionDep, chart_record_id: int):
+    try:
+        return get_chat_predict_data(chart_record_id=chart_record_id, session=session)
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=str(e)
+        )
+
+
 @router.post("/rename")
 async def rename(session: SessionDep, chat: RenameChat):
     try:
@@ -61,7 +83,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
     )
 
 
-@router.get("/recommend_questions/{chat_record_id}")
+@router.post("/recommend_questions/{chat_record_id}")
 async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int):
     try:
         record = session.query(ChatRecord).get(chat_record_id)
@@ -74,11 +96,14 @@ async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int):
 
         llm_service = LLMService(session, current_user, request_question)
         llm_service.set_record(record)
-
-        return run_recommend_questions_task(llm_service)
     except Exception as e:
         traceback.print_exc()
-        return '[]'
+        raise HTTPException(
+            status_code=500,
+            detail=str(e)
+        )
+
+    return StreamingResponse(run_recommend_questions_task(llm_service), media_type="text/event-stream")
 
 
 @router.post("/question")
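For reference, a hypothetical client for the new endpoints; the host, route prefix, record id, and auth header are placeholders (the router prefix is not visible in this hunk). `recommend_questions` switched from GET to POST and now streams server-sent events instead of returning a JSON string, with failures surfacing as HTTP 500 rather than a silent `'[]'`:

```python
# Hypothetical client sketch; base URL and auth are placeholders —
# adjust to however the router is actually mounted.
import requests

BASE = "http://localhost:8000/chat"            # assumed mount point
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder auth

# Chart data for a chat record (new GET endpoint).
data = requests.get(f"{BASE}/record/get/42/data", headers=HEADERS).json()

# Recommended questions now arrive as an SSE stream over POST.
with requests.post(f"{BASE}/recommend_questions/42",
                   headers=HEADERS, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line:  # SSE frames are separated by blank lines
            print(line)
```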
