Skip to content

Commit dd594c2

Browse files
Merge branch 'main' into dev
2 parents a489001 + ea92686 commit dd594c2

File tree

12 files changed

+287
-227
lines changed

12 files changed

+287
-227
lines changed

.env.example

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,5 @@ SENTRY_DSN=
2626
# Configure these with your own Docker registry images
2727
DOCKER_IMAGE_BACKEND=backend
2828
DOCKER_IMAGE_FRONTEND=frontend
29+
30+
MCP_IMAGE_PATH=/opt/sqlbot/images

backend/apps/api.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
from fastapi import APIRouter
22

3-
from apps.system.api import login, user, aimodel
4-
from apps.settings.api import terminology
5-
from apps.datasource.api import datasource
63
from apps.chat.api import chat
74
from apps.dashboard.api import dashboard_api
5+
from apps.datasource.api import datasource
6+
from apps.settings.api import terminology
7+
from apps.system.api import login, user, aimodel
8+
from apps.mcp import mcp
89

910
api_router = APIRouter()
1011
api_router.include_router(login.router)
1112
api_router.include_router(user.router)
1213
api_router.include_router(aimodel.router)
1314
api_router.include_router(terminology.router)
1415
api_router.include_router(datasource.router)
16+
# api_router.include_router(row_permission.router)
17+
# api_router.include_router(column_permission.router)
1518
api_router.include_router(chat.router)
1619
api_router.include_router(dashboard_api.router)
20+
api_router.include_router(mcp.router)
21+

backend/apps/chat/api/chat.py

Lines changed: 16 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,14 @@
11
import traceback
2-
from typing import List
32

43
import orjson
54
from fastapi import APIRouter, HTTPException
65
from fastapi.responses import StreamingResponse
7-
from sqlmodel import select
86

97
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
10-
delete_chat, list_records
11-
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, Chat, ChatQuestion, ChatMcp
12-
from apps.chat.task.llm import LLMService
13-
from apps.datasource.crud.datasource import get_table_schema
14-
from apps.datasource.models.datasource import CoreDatasource
15-
from apps.system.crud.user import get_user_info
16-
from apps.system.models.system_model import AiModelDetail
17-
from common.core.deps import SessionDep, CurrentUser, get_current_user
8+
delete_chat
9+
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion
10+
from apps.chat.task.llm import LLMService, run_task
11+
from common.core.deps import SessionDep, CurrentUser
1812

1913
router = APIRouter(tags=["Data Q&A"], prefix="/chat")
2014

@@ -57,34 +51,6 @@ async def delete(session: SessionDep, chart_id: int):
5751
)
5852

5953

60-
@router.post("/mcp_start", operation_id="mcp_start")
61-
async def mcp_start(session: SessionDep, chat: ChatMcp):
62-
user = await get_current_user(session, chat.token)
63-
return create_chat(session, user, CreateChat(), False)
64-
65-
66-
@router.post("/mcp_question", operation_id="mcp_question")
67-
async def mcp_question(session: SessionDep, chat: ChatMcp):
68-
user = await get_current_user(session, chat.token)
69-
# return await stream_sql(session, user, chat)
70-
return {"content": """这是一段写死的测试内容:
71-
72-
步骤1: 确定需要查询的字段。
73-
我们需要统计上海的订单总数,因此需要从"城市"字段中筛选出值为"上海"的记录,并使用COUNT函数计算这些记录的数量。
74-
75-
步骤2: 确定筛选条件。
76-
问题要求统计上海的订单总数,所以我们需要在SQL语句中添加WHERE "城市" = '上海'来筛选出符合条件的记录。
77-
78-
步骤3: 避免关键字冲突。
79-
因为这个Excel/CSV数据库是 PostgreSQL 类型,所以在schema、表名、字段名和别名外层加双引号。
80-
81-
最终答案:
82-
```json
83-
{"success":true,"sql":"SELECT COUNT(*) AS \"TotalOrders\" FROM \"public\".\"Sheet1_c27345b66e\" WHERE \"城市\" = '上海';"}
84-
```
85-
<img src="https://sqlbot.fit2cloud.cn/images/111.png">"""}
86-
87-
8854
@router.post("/start")
8955
async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat):
9056
try:
@@ -96,7 +62,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
9662
)
9763

9864

99-
@router.post("/question", operation_id="question")
65+
@router.post("/question")
10066
async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion):
10167
"""Stream SQL analysis results
10268
@@ -109,107 +75,17 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
10975
Streaming response with analysis results
11076
"""
11177

112-
chat = session.query(Chat).filter(Chat.id == request_question.chat_id).first()
113-
if not chat:
114-
raise HTTPException(
115-
status_code=400,
116-
detail=f"Chat with id {request_question.chat_id} not found"
117-
)
118-
ds: CoreDatasource | None = None
119-
if chat.datasource:
120-
# Get available datasource
121-
ds = session.query(CoreDatasource).filter(CoreDatasource.id == chat.datasource).first()
122-
if not ds:
123-
raise HTTPException(
124-
status_code=500,
125-
detail="No available datasource configuration found"
126-
)
127-
128-
request_question.engine = ds.type_name if ds.type != 'excel' else 'PostgreSQL'
129-
130-
# Get available AI model
131-
aimodel = session.exec(select(AiModelDetail).where(
132-
AiModelDetail.status == True,
133-
AiModelDetail.api_key.is_not(None)
134-
)).first()
135-
if not aimodel:
78+
try:
79+
llm_service = LLMService(session, current_user, request_question)
80+
llm_service.init_record()
81+
except Exception as e:
82+
traceback.print_exc()
13683
raise HTTPException(
13784
status_code=500,
138-
detail="No available AI model configuration found"
85+
detail=str(e)
13986
)
14087

141-
history_records: List[ChatRecord] = list(filter(lambda r: True if r.first_chat != True else False,
142-
list_records(session=session, current_user=current_user,
143-
chart_id=request_question.chat_id)))
144-
# get schema
145-
if ds:
146-
request_question.db_schema = get_table_schema(session=session, ds=ds)
147-
148-
db_user = get_user_info(session=session, user_id=current_user.id)
149-
request_question.lang = db_user.language
150-
151-
llm_service = LLMService(request_question, aimodel, history_records,
152-
CoreDatasource(**ds.model_dump()) if ds else None)
153-
154-
llm_service.init_record(session=session, current_user=current_user)
155-
156-
def run_task():
157-
try:
158-
# return id
159-
yield orjson.dumps({'type': 'id', 'id': llm_service.get_record().id}).decode() + '\n\n'
160-
161-
# select datasource if datasource is none
162-
if not ds:
163-
ds_res = llm_service.select_datasource(session=session)
164-
for chunk in ds_res:
165-
yield orjson.dumps({'content': chunk, 'type': 'datasource-result'}).decode() + '\n\n'
166-
yield orjson.dumps({'id': llm_service.ds.id, 'datasource_name': llm_service.ds.name,
167-
'engine_type': llm_service.ds.type_name, 'type': 'datasource'}).decode() + '\n\n'
168-
169-
llm_service.chat_question.db_schema = get_table_schema(session=session, ds=llm_service.ds)
170-
171-
# generate sql
172-
sql_res = llm_service.generate_sql(session=session)
173-
full_sql_text = ''
174-
for chunk in sql_res:
175-
full_sql_text += chunk
176-
yield orjson.dumps({'content': chunk, 'type': 'sql-result'}).decode() + '\n\n'
177-
yield orjson.dumps({'type': 'info', 'msg': 'sql generated'}).decode() + '\n\n'
178-
179-
# filter sql
180-
print(full_sql_text)
181-
sql = llm_service.check_save_sql(session=session, res=full_sql_text)
182-
print(sql)
183-
yield orjson.dumps({'content': sql, 'type': 'sql'}).decode() + '\n\n'
184-
185-
# execute sql
186-
result = llm_service.execute_sql(sql=sql)
187-
llm_service.save_sql_data(session=session, data_obj=result)
188-
yield orjson.dumps({'content': orjson.dumps(result).decode(), 'type': 'sql-data'}).decode() + '\n\n'
189-
190-
# generate chart
191-
chart_res = llm_service.generate_chart(session=session)
192-
full_chart_text = ''
193-
for chunk in chart_res:
194-
full_chart_text += chunk
195-
yield orjson.dumps({'content': chunk, 'type': 'chart-result'}).decode() + '\n\n'
196-
yield orjson.dumps({'type': 'info', 'msg': 'chart generated'}).decode() + '\n\n'
197-
198-
# filter chart
199-
print(full_chart_text)
200-
chart = llm_service.check_save_chart(session=session, res=full_chart_text)
201-
print(chart)
202-
yield orjson.dumps({'content': orjson.dumps(chart).decode(), 'type': 'chart'}).decode() + '\n\n'
203-
204-
llm_service.finish(session=session)
205-
yield orjson.dumps({'type': 'finish'}).decode() + '\n\n'
206-
207-
except Exception as e:
208-
traceback.print_exc()
209-
llm_service.save_error(session=session, message=str(e))
210-
yield orjson.dumps({'content': str(e), 'type': 'error'}).decode() + '\n\n'
211-
212-
return StreamingResponse(run_task(), media_type="text/event-stream")
88+
return StreamingResponse(run_task(llm_service, session), media_type="text/event-stream")
21389

21490

21591
@router.post("/record/{chart_record_id}/{action_type}")
@@ -233,35 +109,9 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
233109
detail=f"Chat record with id {chart_record_id} has not generated chart, do not support to analyze it"
234110
)
235111

236-
chat = session.query(Chat).filter(Chat.id == record.chat_id).first()
237-
if not chat:
238-
raise HTTPException(
239-
status_code=400,
240-
detail=f"Chat with id {record.chart_id} not found"
241-
)
242-
243-
if chat.create_by != current_user.id:
244-
raise HTTPException(
245-
status_code=401,
246-
detail=f"You cannot use the chat with id {record.chart_id}"
247-
)
248-
249-
# Get available AI model
250-
aimodel = session.exec(select(AiModelDetail).where(
251-
AiModelDetail.status == True,
252-
AiModelDetail.api_key.is_not(None)
253-
)).first()
254-
if not aimodel:
255-
raise HTTPException(
256-
status_code=500,
257-
detail="No available AI model configuration found"
258-
)
259-
260-
request_question = ChatQuestion(chat_id=chat.id, question='')
261-
db_user = get_user_info(session=session, user_id=current_user.id)
262-
request_question.lang = db_user.language
112+
request_question = ChatQuestion(chat_id=record.chat_id, question='')
263113

264-
llm_service = LLMService(request_question, aimodel)
114+
llm_service = LLMService(session, current_user, request_question)
265115
llm_service.set_record(record)
266116

267117
def run_task():
@@ -277,14 +127,14 @@ def run_task():
277127

278128
elif action_type == 'predict':
279129
# generate predict
280-
analysis_res = llm_service.generate_predict(session=session)
130+
analysis_res = llm_service.generate_predict()
281131
full_text = ''
282132
for chunk in analysis_res:
283133
yield orjson.dumps({'content': chunk, 'type': 'predict-result'}).decode() + '\n\n'
284134
full_text += chunk
285135
yield orjson.dumps({'type': 'info', 'msg': 'predict generated'}).decode() + '\n\n'
286136

287-
_data = llm_service.check_save_predict_data(session=session, res=full_text)
137+
_data = llm_service.check_save_predict_data(res=full_text)
288138
yield orjson.dumps({'type': 'predict', 'content': _data}).decode() + '\n\n'
289139

290140
yield orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n'

0 commit comments

Comments
 (0)