Skip to content

Commit 5c697f9

Browse files
committed
feat: MCP mcp/mcp_question support param to disable streaming
1 parent f52178f commit 5c697f9

File tree

5 files changed

+183
-92
lines changed

5 files changed

+183
-92
lines changed

backend/apps/chat/models/chat_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class CreateChat(BaseModel):
136136
id: int = None
137137
question: str = None
138138
datasource: int = None
139-
origin: Optional[int] = 0
139+
origin: Optional[int] = 0 # 0是页面上,mcp是1,小助手是2
140140

141141

142142
class RenameChat(BaseModel):
@@ -246,6 +246,7 @@ class McpQuestion(BaseModel):
246246
question: str = Body(description='用户提问')
247247
chat_id: int = Body(description='会话ID')
248248
token: str = Body(description='token')
249+
stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True)
249250

250251

251252
class AxisObj(BaseModel):
@@ -264,3 +265,4 @@ class McpAssistant(BaseModel):
264265
question: str = Body(description='用户提问')
265266
url: str = Body(description='第三方数据接口')
266267
authorization: str = Body(description='第三方接口凭证')
268+
stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True)

backend/apps/chat/task/llm.py

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -933,14 +933,17 @@ def await_result(self):
933933
break
934934
yield chunk
935935

936-
def run_task_async(self, in_chat: bool = True):
937-
self.future = executor.submit(self.run_task_cache, in_chat)
936+
def run_task_async(self, in_chat: bool = True, stream: bool = True):
937+
if in_chat:
938+
stream = True
939+
self.future = executor.submit(self.run_task_cache, in_chat, stream)
938940

939-
def run_task_cache(self, in_chat: bool = True):
940-
for chunk in self.run_task(in_chat):
941+
def run_task_cache(self, in_chat: bool = True, stream: bool = True):
942+
for chunk in self.run_task(in_chat, stream):
941943
self.chunk_list.append(chunk)
942944

943-
def run_task(self, in_chat: bool = True):
945+
def run_task(self, in_chat: bool = True, stream: bool = True):
946+
json_result = {'success': True}
944947
try:
945948
if self.ds:
946949
oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1
@@ -955,6 +958,8 @@ def run_task(self, in_chat: bool = True):
955958
# return id
956959
if in_chat:
957960
yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n'
961+
if not stream:
962+
json_result['record_id'] = self.get_record().id
958963

959964
# return title
960965
if self.change_title:
@@ -964,8 +969,10 @@ def run_task(self, in_chat: bool = True):
964969
brief=self.chat_question.question.strip()[:20]))
965970
if in_chat:
966971
yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n'
972+
if not stream:
973+
json_result['title'] = brief
967974

968-
# select datasource if datasource is none
975+
# select datasource if datasource is none
969976
if not self.ds:
970977
ds_res = self.select_datasource()
971978

@@ -1003,7 +1010,6 @@ def run_task(self, in_chat: bool = True):
10031010
'type': 'sql-result'}).decode() + '\n\n'
10041011
if in_chat:
10051012
yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'sql generated'}).decode() + '\n\n'
1006-
10071013
# filter sql
10081014
SQLBotLogUtil.info(full_sql_text)
10091015

@@ -1039,11 +1045,16 @@ def run_task(self, in_chat: bool = True):
10391045
sql = self.check_save_sql(res=full_sql_text)
10401046

10411047
SQLBotLogUtil.info(sql)
1048+
1049+
if not stream:
1050+
json_result['sql'] = sql
1051+
10421052
format_sql = sqlparse.format(sql, reindent=True)
10431053
if in_chat:
10441054
yield 'data:' + orjson.dumps({'content': format_sql, 'type': 'sql'}).decode() + '\n\n'
10451055
else:
1046-
yield f'```sql\n{format_sql}\n```\n\n'
1056+
if stream:
1057+
yield f'```sql\n{format_sql}\n```\n\n'
10471058

10481059
# execute sql
10491060
real_execute_sql = sql
@@ -1058,6 +1069,8 @@ def run_task(self, in_chat: bool = True):
10581069
self.save_sql_data(data_obj=result)
10591070
if in_chat:
10601071
yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n'
1072+
if not stream:
1073+
json_result['data'] = result.get('data')
10611074

10621075
# generate chart
10631076
chart_res = self.generate_chart(chart_type)
@@ -1075,41 +1088,46 @@ def run_task(self, in_chat: bool = True):
10751088
SQLBotLogUtil.info(full_chart_text)
10761089
chart = self.check_save_chart(res=full_chart_text)
10771090
SQLBotLogUtil.info(chart)
1091+
1092+
if not stream:
1093+
json_result['chart'] = chart
1094+
10781095
if in_chat:
10791096
yield 'data:' + orjson.dumps(
10801097
{'content': orjson.dumps(chart).decode(), 'type': 'chart'}).decode() + '\n\n'
10811098
else:
1082-
data = []
1083-
_fields = {}
1084-
if chart.get('columns'):
1085-
for _column in chart.get('columns'):
1086-
if _column:
1087-
_fields[_column.get('value')] = _column.get('name')
1088-
if chart.get('axis'):
1089-
if chart.get('axis').get('x'):
1090-
_fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name')
1091-
if chart.get('axis').get('y'):
1092-
_fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name')
1093-
if chart.get('axis').get('series'):
1094-
_fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get(
1095-
'name')
1096-
_fields_list = []
1097-
_fields_skip = False
1098-
for _data in result.get('data'):
1099-
_row = []
1100-
for field in result.get('fields'):
1101-
_row.append(_data.get(field))
1102-
if not _fields_skip:
1103-
_fields_list.append(field if not _fields.get(field) else _fields.get(field))
1104-
data.append(_row)
1105-
_fields_skip = True
1106-
1107-
if not data or not _fields_list:
1108-
yield 'The SQL execution result is empty.\n\n'
1109-
else:
1110-
df = pd.DataFrame(np.array(data), columns=_fields_list)
1111-
markdown_table = df.to_markdown(index=False)
1112-
yield markdown_table + '\n\n'
1099+
if stream:
1100+
data = []
1101+
_fields = {}
1102+
if chart.get('columns'):
1103+
for _column in chart.get('columns'):
1104+
if _column:
1105+
_fields[_column.get('value')] = _column.get('name')
1106+
if chart.get('axis'):
1107+
if chart.get('axis').get('x'):
1108+
_fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name')
1109+
if chart.get('axis').get('y'):
1110+
_fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name')
1111+
if chart.get('axis').get('series'):
1112+
_fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get(
1113+
'name')
1114+
_fields_list = []
1115+
_fields_skip = False
1116+
for _data in result.get('data'):
1117+
_row = []
1118+
for field in result.get('fields'):
1119+
_row.append(_data.get(field))
1120+
if not _fields_skip:
1121+
_fields_list.append(field if not _fields.get(field) else _fields.get(field))
1122+
data.append(_row)
1123+
_fields_skip = True
1124+
1125+
if not data or not _fields_list:
1126+
yield 'The SQL execution result is empty.\n\n'
1127+
else:
1128+
df = pd.DataFrame(np.array(data), columns=_fields_list)
1129+
markdown_table = df.to_markdown(index=False)
1130+
yield markdown_table + '\n\n'
11131131

11141132
if in_chat:
11151133
yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n'
@@ -1119,7 +1137,14 @@ def run_task(self, in_chat: bool = True):
11191137
yield '### generated chart picture\n\n'
11201138
image_url = request_picture(self.record.chat_id, self.record.id, chart, result)
11211139
SQLBotLogUtil.info(image_url)
1122-
yield f'![{chart["type"]}]({image_url})'
1140+
if stream:
1141+
yield f'![{chart["type"]}]({image_url})'
1142+
else:
1143+
json_result['image_url'] = image_url
1144+
1145+
if not stream:
1146+
yield json_result
1147+
11231148
except Exception as e:
11241149
traceback.print_exc()
11251150
error_msg: str
@@ -1137,7 +1162,12 @@ def run_task(self, in_chat: bool = True):
11371162
if in_chat:
11381163
yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n'
11391164
else:
1140-
yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。'
1165+
if stream:
1166+
yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。'
1167+
else:
1168+
json_result['success'] = False
1169+
json_result['message'] = error_msg
1170+
yield json_result
11411171
finally:
11421172
self.finish()
11431173

backend/apps/mcp/mcp.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Author: Junjun
22
# Date: 2025/7/1
33
import json
4+
import traceback
45
from datetime import timedelta
56

67
import jwt
@@ -10,6 +11,7 @@
1011
from jwt.exceptions import InvalidTokenError
1112
from pydantic import ValidationError
1213
from sqlmodel import select
14+
from starlette.responses import JSONResponse
1315

1416
from apps.chat.api.chat import create_chat
1517
from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion, McpAssistant, ChatQuestion
@@ -106,11 +108,37 @@ async def mcp_question(session: SessionDep, chat: McpQuestion):
106108
raise HTTPException(status_code=400, detail="Inactive user")
107109

108110
mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question)
109-
# ask
110-
llm_service = await LLMService.create(session_user, mcp_chat)
111-
llm_service.init_record()
112111

113-
return StreamingResponse(llm_service.run_task(False), media_type="text/event-stream")
112+
try:
113+
llm_service = await LLMService.create(session_user, mcp_chat)
114+
llm_service.init_record()
115+
llm_service.run_task_async(False, chat.stream)
116+
except Exception as e:
117+
traceback.print_exc()
118+
119+
if chat.stream:
120+
def _err(_e: Exception):
121+
yield str(_e) + '\n\n'
122+
123+
return StreamingResponse(_err(e), media_type="text/event-stream")
124+
else:
125+
return {'message': str(e)}
126+
if chat.stream:
127+
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
128+
else:
129+
res = llm_service.await_result()
130+
raw_data = {}
131+
for chunk in res:
132+
if chunk:
133+
raw_data = chunk
134+
status_code = 200
135+
if not raw_data.get('success'):
136+
status_code = 500
137+
138+
return JSONResponse(
139+
content=raw_data,
140+
status_code=status_code,
141+
)
114142

115143

116144
@router.post("/mcp_assistant", operation_id="mcp_assistant")

0 commit comments

Comments
 (0)