Skip to content

Commit 07b90b9

Browse files
committed
feat: MCP mcp/mcp_assistant support param to disable streaming
1 parent d945ab6 commit 07b90b9

File tree

3 files changed

+88
-12
lines changed

3 files changed

+88
-12
lines changed

backend/apps/chat/models/chat_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ class OperationEnum(Enum):
4040
CHOOSE_DATASOURCE = '6'
4141
GENERATE_DYNAMIC_SQL = '7'
4242

43+
class ChatFinishStep(Enum):
    """Pipeline step after which chat processing should stop.

    Values are ordered by pipeline progression, so callers compare with
    ``finish_step.value <= ChatFinishStep.X.value`` to decide whether to
    finish early (e.g. stop after SQL generation without querying data
    or rendering a chart).
    """
    GENERATE_SQL = 1   # stop once the SQL statement has been generated
    QUERY_DATA = 2     # stop once the SQL has been executed and data fetched
    GENERATE_CHART = 3 # run the full pipeline including chart generation
4347

4448
# TODO choose table / check connection / generate description
4549

backend/apps/chat/task/llm.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \
2929
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
3030
get_last_execute_sql_error
31-
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum
31+
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
32+
ChatFinishStep
3233
from apps.data_training.curd.data_training import get_training_template
3334
from apps.datasource.crud.datasource import get_table_schema
3435
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
@@ -934,17 +935,20 @@ def await_result(self):
934935
break
935936
yield chunk
936937

937-
def run_task_async(self, in_chat: bool = True, stream: bool = True):
938+
def run_task_async(self, in_chat: bool = True, stream: bool = True,
939+
finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART):
938940
if in_chat:
939941
stream = True
940-
self.future = executor.submit(self.run_task_cache, in_chat, stream)
942+
self.future = executor.submit(self.run_task_cache, in_chat, stream, finish_step)
941943

942-
def run_task_cache(self, in_chat: bool = True, stream: bool = True):
943-
for chunk in self.run_task(in_chat, stream):
944+
def run_task_cache(self, in_chat: bool = True, stream: bool = True,
945+
finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART):
946+
for chunk in self.run_task(in_chat, stream, finish_step):
944947
self.chunk_list.append(chunk)
945948

946-
def run_task(self, in_chat: bool = True, stream: bool = True):
947-
json_result = {'success': True}
949+
def run_task(self, in_chat: bool = True, stream: bool = True,
950+
finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART):
951+
json_result: Dict[str, Any] = {'success': True}
948952
try:
949953
if self.ds:
950954
oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1
@@ -1066,13 +1070,47 @@ def run_task(self, in_chat: bool = True, stream: bool = True):
10661070
subsql)
10671071
real_execute_sql = assistant_dynamic_sql
10681072

1073+
if finish_step.value <= ChatFinishStep.GENERATE_SQL.value:
1074+
if in_chat:
1075+
yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n'
1076+
if not stream:
1077+
yield json_result
1078+
return
1079+
10691080
result = self.execute_sql(sql=real_execute_sql)
10701081
self.save_sql_data(data_obj=result)
10711082
if in_chat:
10721083
yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n'
10731084
if not stream:
10741085
json_result['data'] = result.get('data')
10751086

1087+
if finish_step.value <= ChatFinishStep.QUERY_DATA.value:
1088+
if stream:
1089+
if in_chat:
1090+
yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n'
1091+
else:
1092+
data = []
1093+
_fields_list = []
1094+
_fields_skip = False
1095+
for _data in result.get('data'):
1096+
_row = []
1097+
for field in result.get('fields'):
1098+
_row.append(_data.get(field))
1099+
if not _fields_skip:
1100+
_fields_list.append(field)
1101+
data.append(_row)
1102+
_fields_skip = True
1103+
1104+
if not data or not _fields_list:
1105+
yield 'The SQL execution result is empty.\n\n'
1106+
else:
1107+
df = pd.DataFrame(np.array(data), columns=_fields_list)
1108+
markdown_table = df.to_markdown(index=False)
1109+
yield markdown_table + '\n\n'
1110+
else:
1111+
yield json_result
1112+
return
1113+
10761114
# generate chart
10771115
chart_res = self.generate_chart(chart_type)
10781116
full_chart_text = ''

backend/apps/mcp/mcp.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from starlette.responses import JSONResponse
1515

1616
from apps.chat.api.chat import create_chat
17-
from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion, McpAssistant, ChatQuestion
17+
from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion, McpAssistant, ChatQuestion, \
18+
ChatFinishStep
1819
from apps.chat.task.llm import LLMService
1920
from apps.system.crud.user import authenticate
2021
from apps.system.crud.user import get_db_user
@@ -122,7 +123,10 @@ def _err(_e: Exception):
122123

123124
return StreamingResponse(_err(e), media_type="text/event-stream")
124125
else:
125-
return {'message': str(e)}
126+
return JSONResponse(
127+
content={'message': str(e)},
128+
status_code=500,
129+
)
126130
if chat.stream:
127131
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
128132
else:
@@ -162,6 +166,36 @@ async def mcp_assistant(session: SessionDep, chat: McpAssistant):
162166
# assistant question
163167
mcp_chat = ChatQuestion(chat_id=c.id, question=chat.question)
164168
# ask
165-
llm_service = await LLMService.create(session_user, mcp_chat, mcp_assistant_header)
166-
llm_service.init_record()
167-
return llm_service.run_task(False)
169+
try:
170+
llm_service = await LLMService.create(session_user, mcp_chat, mcp_assistant_header)
171+
llm_service.init_record()
172+
llm_service.run_task_async(False, chat.stream, ChatFinishStep.QUERY_DATA)
173+
except Exception as e:
174+
traceback.print_exc()
175+
176+
if chat.stream:
177+
def _err(_e: Exception):
178+
yield str(_e) + '\n\n'
179+
180+
return StreamingResponse(_err(e), media_type="text/event-stream")
181+
else:
182+
return JSONResponse(
183+
content={'message': str(e)},
184+
status_code=500,
185+
)
186+
if chat.stream:
187+
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
188+
else:
189+
res = llm_service.await_result()
190+
raw_data = {}
191+
for chunk in res:
192+
if chunk:
193+
raw_data = chunk
194+
status_code = 200
195+
if not raw_data.get('success'):
196+
status_code = 500
197+
198+
return JSONResponse(
199+
content=raw_data,
200+
status_code=status_code,
201+
)

0 commit comments

Comments (0)