Commit dfdc5a4

perf: Data Q&A
1 parent de8bad4 commit dfdc5a4

8 files changed, +387 -64 lines

backend/apps/chat/api/chat.py

Lines changed: 60 additions & 16 deletions
@@ -1,43 +1,65 @@
 from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
 from sqlmodel import select
 from apps.chat.schemas.chat_base_schema import LLMConfig
 from apps.chat.schemas.chat_schema import ChatQuestion
-from apps.chat.schemas.llm import LLMService
+from apps.chat.schemas.llm import AgentService, LLMService
+from apps.datasource.models.datasource import CoreDatasource
 from apps.system.models.system_modle import AiModelDetail
 from common.core.deps import SessionDep
-# from sse_starlette.sse import EventSourceResponse
+from sse_starlette.sse import EventSourceResponse
+import json
+import asyncio
+
 router = APIRouter(tags=["Data Q&A"], prefix="/chat")


 @router.post("/question")
 async def stream_sql(session: SessionDep, requestQuestion: ChatQuestion):
+    """Stream SQL analysis results
+
+    Args:
+        session: Database session
+        requestQuestion: User question model
+
+    Returns:
+        Streaming response with analysis results
+    """
     question = requestQuestion.question

-    # Use OpenAI model
-    """ openai_config = LLMConfig(
-        model_type="openai",
-        model_name="gpt-4",
-        api_key="your-api-key",
-        additional_params={"temperature": 0.7}
-    )
-    openai_service = LLMService(openai_config) """
-
-    aimodel = session.exec(select(AiModelDetail).where(AiModelDetail.status == True, AiModelDetail.api_key.is_not(None))).first()
+    # Get available AI model
+    aimodel = session.exec(select(AiModelDetail).where(
+        AiModelDetail.status == True,
+        AiModelDetail.api_key.is_not(None)
+    )).first()

+    # Get available datasource
+    ds = session.exec(select(CoreDatasource).where(
+        CoreDatasource.status == 'Success'
+    )).first()
+
     if not aimodel:
         raise HTTPException(
             status_code=400,
             detail="No available AI model configuration found"
         )
+
+    if not ds:
+        raise HTTPException(
+            status_code=400,
+            detail="No available datasource configuration found"
+        )

     # Use Tongyi Qianwen
     tongyi_config = LLMConfig(
-        model_type="tongyi",
+        model_type="openai",
         model_name=aimodel.name,
         api_key=aimodel.api_key,
+        api_base_url=aimodel.endpoint,
         additional_params={"temperature": aimodel.temperature}
     )
-    llm_service = LLMService(tongyi_config)
+    # llm_service = LLMService(tongyi_config)
+    llm_service = AgentService(tongyi_config, ds)

     # Use Custom VLLM model
     """ vllm_config = LLMConfig(
@@ -49,5 +71,27 @@ async def stream_sql(session: SessionDep, requestQuestion: ChatQuestion):
         }
     )
     vllm_service = LLMService(vllm_config) """
-    result = llm_service.generate_sql(question)
-    return result
+    """ result = llm_service.generate_sql(question)
+    return result """
+
+    async def event_generator():
+        try:
+            async for chunk in llm_service.async_generate(question):
+                data = json.loads(chunk.replace('data: ', ''))
+
+                if data['type'] in ['final', 'tool_result']:
+                    content = data['content']
+                    for char in content:
+                        yield f"data: {json.dumps({'type': 'char', 'content': char})}\n\n"
+                        await asyncio.sleep(0.05)
+
+                    if 'html' in data:
+                        yield f"data: {json.dumps({'type': 'html', 'content': data['html']})}\n\n"
+                else:
+                    yield chunk
+
+        except Exception as e:
+            yield f"data: {json.dumps({'type': 'error', 'content': str(e)})}\n\n"
+
+    # return EventSourceResponse(event_generator(), headers={"Content-Type": "text/event-stream"})
+    return StreamingResponse(event_generator(), media_type="text/event-stream")
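
Note: a minimal sketch of a client for the new streaming endpoint, mirroring what the frontend's request.fetchStream consumes. The base URL, the absence of an extra API prefix, and the use of httpx are assumptions for illustration, not part of this commit.

import json
import httpx

def ask(question: str) -> None:
    # "question" is the field stream_sql reads from ChatQuestion.
    payload = {"question": question}
    # Base URL is hypothetical; the real deployment may add an API prefix.
    with httpx.stream("POST", "http://localhost:8000/chat/question",
                      json=payload, timeout=None) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank SSE separators
            event = json.loads(line[len("data: "):])
            if event["type"] == "char":
                print(event["content"], end="", flush=True)  # typewriter text
            elif event["type"] == "html":
                print("\n[HTML report received]")
            elif event["type"] == "error":
                print(f"\n[error] {event['content']}")

if __name__ == "__main__":
    ask("How many orders were placed last month?")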

backend/apps/chat/schemas/chat_base_schema.py

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ def _init_llm(self) -> LangchainBaseLLM:
         return ChatOpenAI(
             model=self.config.model_name,
             api_key=self.config.api_key,
+            base_url=self.config.api_base_url,
             **self.config.additional_params
         )
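
Note: this one-line change threads api_base_url through to ChatOpenAI's base_url, so any OpenAI-compatible endpoint can be targeted. A hypothetical configuration for illustration (model name, key, and URL are placeholders, not part of this commit):

from apps.chat.schemas.chat_base_schema import LLMConfig

config = LLMConfig(
    model_type="openai",
    model_name="qwen-plus",  # placeholder model name
    api_key="sk-...",        # placeholder key
    api_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    additional_params={"temperature": 0.7},
)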

backend/apps/chat/schemas/llm.py

Lines changed: 160 additions & 2 deletions
@@ -1,10 +1,17 @@
 from langchain_community.utilities import SQLDatabase
-# from langchain_community.agent_toolkits import create_sql_agent
-from langchain_community.llms import Tongyi
+from langgraph.prebuilt import create_react_agent
 from langchain_core.prompts import ChatPromptTemplate
 from apps.chat.schemas.chat_base_schema import LLMConfig, LLMFactory
+from apps.datasource.models.datasource import CoreDatasource
+from apps.db.db import exec_sql, get_uri
 from common.core.config import settings
 import warnings
+from langchain.tools import Tool
+from functools import partial
+import logging
+from typing import AsyncGenerator
+import json
+import asyncio

 warnings.filterwarnings("ignore")

@@ -31,3 +38,154 @@ def generate_sql(self, question: str) -> str:
         schema = self.db.get_table_info()
         return chain.invoke({"schema": schema, "question": question})

+
+class AgentService:
+    def __init__(self, config: LLMConfig, ds: CoreDatasource):
+        # Initialize database connection
+        self.ds = ds
+        db_uri = get_uri(ds)
+        self.db = SQLDatabase.from_uri(db_uri)
+        # self.db = SQLDatabase.from_uri(str(settings.SQLALCHEMY_DATABASE_URI))
+
+        # Create LLM instance through factory
+        llm_instance = LLMFactory.create_llm(config)
+        self.llm = llm_instance.llm
+
+        # Create a partial function of execute_sql with a preset db parameter
+        # bound_execute_sql = partial(execute_sql, self.ds)
+        bound_execute_sql = partial(execute_sql_with_db, self.db)
+
+        # Wrap as a Tool object
+        tools = [
+            Tool(
+                name="execute_sql",
+                func=bound_execute_sql,
+                description="""A tool for executing SQL queries.
+                Input: SQL query statement (string)
+                Output: Query results
+                Example: "SELECT * FROM table_name LIMIT 5"
+                """
+            )
+        ]
+
+        self.agent_executor = create_react_agent(self.llm, tools)
+
+        system_prompt = """
+        You are an intelligent agent capable of data analysis. When users state their data analysis requirements,
+        you need to first convert the requirements into executable SQL, then execute the SQL through tools to obtain results,
+        and finally summarize the SQL query results. When all tasks are completed, you need to generate a data analysis report in HTML format.
+
+        You can analyze the requirements step by step to determine the final SQL query to generate.
+        To improve SQL generation accuracy, please evaluate the accuracy of the SQL after generating it;
+        if there are issues, regenerate the SQL.
+        When SQL execution fails, you need to correct the SQL based on the error message and try to execute it again.
+
+        ### Tools ###
+        execute_sql: Executes the SQL statement passed to it and returns the execution results
+        """
+        user_prompt = """
+        Below is the database information I need to query:
+        {schema}
+
+        My requirement is: {question}
+        """
+        # Define prompt template
+        self.prompt = ChatPromptTemplate.from_messages([
+            ("system", system_prompt),
+            ("human", user_prompt)
+        ])
+
+    def generate_sql(self, question: str) -> str:
+        chain = self.prompt | self.agent_executor
+        schema = self.db.get_table_info()
+        return chain.invoke({"schema": schema, "question": question})
+
+    async def async_generate(self, question: str) -> AsyncGenerator[str, None]:
+
+        chain = self.prompt | self.agent_executor
+        schema = self.db.get_table_info()
+
+        async for chunk in chain.astream({"schema": schema, "question": question}):
+            if not isinstance(chunk, dict):
+                continue
+
+            if "agent" in chunk:
+                messages = chunk["agent"].get("messages", [])
+                for msg in messages:
+                    if tool_calls := msg.additional_kwargs.get("tool_calls"):
+                        for tool_call in tool_calls:
+                            response = {
+                                "type": "tool_call",
+                                "tool": tool_call["function"]["name"],
+                                "args": tool_call["function"]["arguments"]
+                            }
+                            yield f"data: {json.dumps(response, ensure_ascii=False)}\n\n"
+
+                    if content := msg.content:
+                        html_start = content.find("```html")
+                        html_end = content.find("```", html_start + 6)
+                        if html_start != -1 and html_end != -1:
+                            html_content = content[html_start + 7:html_end].strip()
+                            response = {
+                                "type": "final",
+                                "content": content.split("```html")[0].strip(),
+                                "html": html_content
+                            }
+                        else:
+                            response = {
+                                "type": "final",
+                                "content": content
+                            }
+                        yield f"data: {json.dumps(response, ensure_ascii=False)}\n\n"
+
+            if "tools" in chunk:
+                messages = chunk["tools"].get("messages", [])
+                for msg in messages:
+                    response = {
+                        "type": "tool_result",
+                        "tool": msg.name,
+                        "content": msg.content
+                    }
+                    yield f"data: {json.dumps(response, ensure_ascii=False)}\n\n"
+
+            await asyncio.sleep(0.1)
+
+        yield f"data: {json.dumps({'type': 'complete'})}\n\n"
+
+def execute_sql(ds: CoreDatasource, sql: str) -> str:
+    """Execute an SQL query
+
+    Args:
+        ds: Data source instance
+        sql: SQL query statement
+
+    Returns:
+        Query results
+    """
+    print(f"Executing SQL on ds_id {ds.id}: {sql}")
+    return exec_sql(ds, sql)
+
+def execute_sql_with_db(db: SQLDatabase, sql: str) -> str:
+    """Execute an SQL query using SQLDatabase
+
+    Args:
+        db: SQLDatabase instance
+        sql: SQL query statement
+
+    Returns:
+        str: Query results formatted as a string
+    """
+    try:
+        # Execute query
+        result = db.run(sql)
+
+        if not result:
+            return "Query executed successfully but returned no results."
+
+        # Format results
+        return str(result)
+
+    except Exception as e:
+        error_msg = f"SQL execution failed: {str(e)}"
+        logging.error(error_msg)
+        raise RuntimeError(error_msg)
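
Note: a rough usage sketch of the new AgentService, assuming a valid LLMConfig and a CoreDatasource row already exist (both placeholders here, constructed as in chat.py above). async_generate yields SSE frames whose JSON bodies carry a "type" of tool_call, tool_result, final, or complete.

import asyncio
from apps.chat.schemas.llm import AgentService

async def main(service: AgentService) -> None:
    async for frame in service.async_generate("Top 5 products by revenue"):
        # Each frame is one SSE chunk of the form "data: {...}\n\n".
        print(frame, end="")

# config and ds are assumed to be built elsewhere:
# asyncio.run(main(AgentService(config, ds)))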

backend/apps/db/db.py

Lines changed: 28 additions & 0 deletions
@@ -5,8 +5,18 @@
 from typing import Any
 import json
 from apps.datasource.utils.utils import aes_decrypt
+from common.core.deps import SessionDep


+def get_uri(ds: CoreDatasource):
+    conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration)))
+    db_url: str
+    if ds.type == "mysql":
+        db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{urllib.parse.quote(conf.database)}"
+    else:
+        raise ValueError('The datasource type is not supported.')
+    return db_url
+
 def get_session(ds: CoreDatasource):
     conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration)))
     db_url: str
@@ -88,3 +98,20 @@ def exec_sql(ds: CoreDatasource, sql: str):
         result.close()
         if session is not None:
             session.close()
+
+def exec_sql(ds: CoreDatasource, sql: str):
+    session = get_session(ds)
+    result = session.execute(text(sql))
+    try:
+        columns = result.keys()._keys
+        res = result.fetchall()
+        result_list = [
+            {columns[i]: value for i, value in enumerate(tuple_item)}
+            for tuple_item in res
+        ]
+        return {"fields": columns, "data": result_list}
+    finally:
+        if result is not None:
+            result.close()
+        if session is not None:
+            session.close()
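
Note: for illustration, the URI shape get_uri builds for a MySQL source, with hypothetical credentials (the real values come from the decrypted DatasourceConf):

import urllib.parse

username, password = "ro_user", "p@ss:word"  # placeholder credentials
host, port, database = "10.0.0.5", 3306, "sales"
uri = (f"mysql+pymysql://{urllib.parse.quote(username)}:"
       f"{urllib.parse.quote(password)}@{host}:{port}/"
       f"{urllib.parse.quote(database)}")
print(uri)  # mysql+pymysql://ro_user:p%40ss%3Aword@10.0.0.5:3306/sales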

backend/pyproject.toml

Lines changed: 5 additions & 4 deletions
@@ -19,10 +19,11 @@ dependencies = [
     "sentry-sdk[fastapi]<2.0.0,>=1.40.6",
     "pyjwt<3.0.0,>=2.8.0",
     "pycryptodome (>=3.22.0,<4.0.0)",
-    "langchain>=0.1.0,<0.2.0",
-    "langchain-core>=0.1.10,<0.2.0",
-    "langchain-openai>=0.1.0,<0.2.0",
-    "langchain-community>=0.0.19,<0.1.0",
+    "langchain>=0.3,<0.4",
+    "langchain-core>=0.3,<0.4",
+    "langchain-openai>=0.3,<0.4",
+    "langchain-community>=0.3,<0.4",
+    "langgraph>=0.3,<0.4",
     "vllm>=0.8.5",
     "dashscope>=1.14.0,<2.0.0",
     "sse-starlette>=1.8.0,<2.0.0",

frontend/src/api/chat.ts

Lines changed: 7 additions & 2 deletions
@@ -2,8 +2,13 @@ import { request } from '@/utils/request'

 export const questionApi = {
   pager: (pageNumber: number, pageSize: number) => request.get(`/chat/question/pager/${pageNumber}/${pageSize}`),
-  // add: (data: any, progress: any) => request.post('/chat/question', data, { responseType: 'stream', onDownloadProgress: progress }),
-  add: (data: any) => request.post('/chat/question', data),
+  /* add: (data: any) => new Promise((resolve, reject) => {
+    request.post('/chat/question', data, { responseType: 'stream', timeout: 0, onDownloadProgress: p => {
+      resolve(p)
+    }}).catch(e => reject(e))
+  }), */
+  // add: (data: any) => request.post('/chat/question', data),
+  add: (data: any) => request.fetchStream('/chat/question', data),
   edit: (data: any) => request.put('/chat/question', data),
   delete: (id: number) => request.delete(`/chat/question/${id}`),
   query: (id: number) => request.get(`/chat/question/${id}`)
