Skip to content

Commit 09b82e9

Browse files
committed
feat: support command in mcp
1 parent 7668a6f commit 09b82e9

File tree

5 files changed

+241
-144
lines changed

5 files changed

+241
-144
lines changed

backend/apps/chat/api/chat.py

Lines changed: 84 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from fastapi import APIRouter, HTTPException, Path
99
from fastapi.responses import StreamingResponse
1010
from sqlalchemy import and_, select
11+
from starlette.responses import JSONResponse
1112

1213
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
1314
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \
1415
format_json_data, format_json_list_data, get_chart_config, list_recent_questions
1516
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand, \
16-
ChatInfo, Chat
17+
ChatInfo, Chat, ChatFinishStep
1718
from apps.chat.task.llm import LLMService
1819
from apps.swagger.i18n import PLACEHOLDER_PREFIX
1920
from apps.system.schemas.permission import SqlbotPermission, require_permissions
@@ -166,11 +167,18 @@ def find_base_question(record_id: int, session: SessionDep):
166167
@require_permissions(permission=SqlbotPermission(type='chat', keyExpression="request_question.chat_id"))
167168
async def question_answer(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
168169
current_assistant: CurrentAssistant):
170+
return await question_answer_inner(session, current_user, request_question, current_assistant, embedding=True)
171+
172+
173+
async def question_answer_inner(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
174+
current_assistant: Optional[CurrentAssistant] = None, in_chat: bool = True,
175+
stream: bool = True,
176+
finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART, embedding: bool = False):
169177
try:
170178
command, text_before_command, record_id, warning_info = parse_quick_command(request_question.question)
171179
if command:
172-
# todo 暂不支持分析和预测,需要改造前端
173-
if command == QuickCommand.ANALYSIS or command == QuickCommand.PREDICT_DATA:
180+
# todo 对话界面下,暂不支持分析和预测,需要改造前端
181+
if in_chat and (command == QuickCommand.ANALYSIS or command == QuickCommand.PREDICT_DATA):
174182
raise Exception(f'Command: {command.value} temporary not supported')
175183

176184
if record_id is not None:
@@ -221,53 +229,83 @@ async def question_answer(session: SessionDep, current_user: CurrentUser, reques
221229
if command == QuickCommand.REGENERATE:
222230
request_question.question = text_before_command
223231
request_question.regenerate_record_id = rec_id
224-
return await stream_sql(session, current_user, request_question, current_assistant)
232+
return await stream_sql(session, current_user, request_question, current_assistant, in_chat, stream,
233+
finish_step, embedding)
225234

226235
elif command == QuickCommand.ANALYSIS:
227-
return await analysis_or_predict(session, current_user, rec_id, 'analysis', current_assistant)
236+
return await analysis_or_predict(session, current_user, rec_id, 'analysis', current_assistant, in_chat, stream)
228237

229238
elif command == QuickCommand.PREDICT_DATA:
230-
return await analysis_or_predict(session, current_user, rec_id, 'predict', current_assistant)
239+
return await analysis_or_predict(session, current_user, rec_id, 'predict', current_assistant, in_chat, stream)
231240
else:
232241
raise Exception(f'Unknown command: {command.value}')
233242
else:
234-
return await stream_sql(session, current_user, request_question, current_assistant)
243+
return await stream_sql(session, current_user, request_question, current_assistant, in_chat, stream,
244+
finish_step, embedding)
235245
except Exception as e:
236246
traceback.print_exc()
237247

238-
def _err(_e: Exception):
239-
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'
248+
if stream:
249+
def _err(_e: Exception):
250+
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'
240251

241-
return StreamingResponse(_err(e), media_type="text/event-stream")
252+
return StreamingResponse(_err(e), media_type="text/event-stream")
253+
else:
254+
return JSONResponse(
255+
content={'message': str(e)},
256+
status_code=500,
257+
)
242258

243259

244260
async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
245-
current_assistant: CurrentAssistant):
261+
current_assistant: Optional[CurrentAssistant] = None, in_chat: bool = True, stream: bool = True,
262+
finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART, embedding: bool = False):
246263
try:
247264
llm_service = await LLMService.create(session, current_user, request_question, current_assistant,
248-
embedding=True)
265+
embedding=embedding)
249266
llm_service.init_record(session=session)
250-
llm_service.run_task_async()
267+
llm_service.run_task_async(in_chat=in_chat, stream=stream, finish_step=finish_step)
251268
except Exception as e:
252269
traceback.print_exc()
253270

254-
def _err(_e: Exception):
255-
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'
256-
257-
return StreamingResponse(_err(e), media_type="text/event-stream")
271+
if stream:
272+
def _err(_e: Exception):
273+
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'
258274

259-
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
275+
return StreamingResponse(_err(e), media_type="text/event-stream")
276+
else:
277+
return JSONResponse(
278+
content={'message': str(e)},
279+
status_code=500,
280+
)
281+
if stream:
282+
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
283+
else:
284+
res = llm_service.await_result()
285+
raw_data = {}
286+
for chunk in res:
287+
if chunk:
288+
raw_data = chunk
289+
status_code = 200
290+
if not raw_data.get('success'):
291+
status_code = 500
292+
293+
return JSONResponse(
294+
content=raw_data,
295+
status_code=status_code,
296+
)
260297

261298

262299
@router.post("/record/{chat_record_id}/{action_type}", summary=f"{PLACEHOLDER_PREFIX}analysis_or_predict")
263300
async def analysis_or_predict_question(session: SessionDep, current_user: CurrentUser,
264301
current_assistant: CurrentAssistant, chat_record_id: int,
265-
action_type: str = Path(..., description=f"{PLACEHOLDER_PREFIX}analysis_or_predict_action_type")):
302+
action_type: str = Path(...,
303+
description=f"{PLACEHOLDER_PREFIX}analysis_or_predict_action_type")):
266304
return await analysis_or_predict(session, current_user, chat_record_id, action_type, current_assistant)
267305

268306

269307
async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str,
270-
current_assistant: CurrentAssistant):
308+
current_assistant: CurrentAssistant, in_chat: bool = True, stream: bool = True):
271309
try:
272310
if action_type != 'analysis' and action_type != 'predict':
273311
raise Exception(f"Type {action_type} Not Found")
@@ -294,16 +332,35 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
294332
request_question = ChatQuestion(chat_id=record.chat_id, question=record.question)
295333

296334
llm_service = await LLMService.create(session, current_user, request_question, current_assistant)
297-
llm_service.run_analysis_or_predict_task_async(session, action_type, record)
335+
llm_service.run_analysis_or_predict_task_async(session, action_type, record, in_chat, stream)
298336
except Exception as e:
299337
traceback.print_exc()
338+
if stream:
339+
def _err(_e: Exception):
340+
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'
300341

301-
def _err(_e: Exception):
302-
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'
303-
304-
return StreamingResponse(_err(e), media_type="text/event-stream")
305-
306-
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
342+
return StreamingResponse(_err(e), media_type="text/event-stream")
343+
else:
344+
return JSONResponse(
345+
content={'message': str(e)},
346+
status_code=500,
347+
)
348+
if stream:
349+
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
350+
else:
351+
res = llm_service.await_result()
352+
raw_data = {}
353+
for chunk in res:
354+
if chunk:
355+
raw_data = chunk
356+
status_code = 200
357+
if not raw_data.get('success'):
358+
status_code = 500
359+
360+
return JSONResponse(
361+
content=raw_data,
362+
status_code=status_code,
363+
)
307364

308365

309366
@router.get("/record/{chat_record_id}/excel/export", summary=f"{PLACEHOLDER_PREFIX}export_chart_data")

backend/apps/chat/curd/chat.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,17 @@ def format_json_list_data(origin_data: list[dict]):
163163
return data
164164

165165

166+
def get_chat_chart_config(session: SessionDep, chat_record_id: int):
167+
stmt = select(ChatRecord.chart).where(and_(ChatRecord.id == chat_record_id))
168+
res = session.execute(stmt)
169+
for row in res:
170+
try:
171+
return orjson.loads(row.data)
172+
except Exception:
173+
pass
174+
return {}
175+
176+
166177
def get_chat_chart_data(session: SessionDep, chat_record_id: int):
167178
stmt = select(ChatRecord.data).where(and_(ChatRecord.id == chat_record_id))
168179
res = session.execute(stmt)

0 commit comments

Comments
 (0)