Skip to content

Commit 8afb148

Browse files
committed
feat: chat question for mcp
1 parent cc9af61 commit 8afb148

File tree

3 files changed

+47
-44
lines changed

3 files changed

+47
-44
lines changed

backend/apps/chat/task/llm.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from typing import Any, List, Union, Dict
44

5+
import numpy as np
56
import orjson
67
import pandas as pd
78
from langchain_community.utilities import SQLDatabase
@@ -543,16 +544,16 @@ def execute_sql_with_db(db: SQLDatabase, sql: str) -> str:
543544
raise RuntimeError(error_msg)
544545

545546

546-
def run_task(llm_service: LLMService, session: SessionDep, stream: bool = True):
547+
def run_task(llm_service: LLMService, session: SessionDep, in_chat: bool = True):
547548
try:
548549
# return id
549-
if stream:
550+
if in_chat:
550551
yield orjson.dumps({'type': 'id', 'id': llm_service.get_record().id}).decode() + '\n\n'
551552

552553
# select datasource if datasource is none
553554
if not llm_service.ds:
554555
ds_res = llm_service.select_datasource()
555-
if stream:
556+
if in_chat:
556557
for chunk in ds_res:
557558
yield orjson.dumps({'content': chunk, 'type': 'datasource-result'}).decode() + '\n\n'
558559
yield orjson.dumps({'id': llm_service.ds.id, 'datasource_name': llm_service.ds.name,
@@ -565,63 +566,81 @@ def run_task(llm_service: LLMService, session: SessionDep, stream: bool = True):
565566
full_sql_text = ''
566567
for chunk in sql_res:
567568
full_sql_text += chunk
568-
if stream:
569+
if in_chat:
569570
yield orjson.dumps({'content': chunk, 'type': 'sql-result'}).decode() + '\n\n'
570-
if stream:
571+
if in_chat:
571572
yield orjson.dumps({'type': 'info', 'msg': 'sql generated'}).decode() + '\n\n'
572573

573574
# filter sql
574575
print(full_sql_text)
575576
sql = llm_service.check_save_sql(res=full_sql_text)
576577
print(sql)
577-
if stream:
578+
if in_chat:
578579
yield orjson.dumps({'content': sql, 'type': 'sql'}).decode() + '\n\n'
580+
else:
581+
yield f'```sql\n{sql}\n```\n\n'
579582

580583
# execute sql
581584
result = llm_service.execute_sql(sql=sql)
582585
llm_service.save_sql_data(data_obj=result)
583-
if stream:
586+
if in_chat:
584587
yield orjson.dumps({'content': orjson.dumps(result).decode(), 'type': 'sql-data'}).decode() + '\n\n'
585588

586589
# generate chart
587590
chart_res = llm_service.generate_chart()
588591
full_chart_text = ''
589592
for chunk in chart_res:
590593
full_chart_text += chunk
591-
if stream:
594+
if in_chat:
592595
yield orjson.dumps({'content': chunk, 'type': 'chart-result'}).decode() + '\n\n'
593-
if stream:
596+
if in_chat:
594597
yield orjson.dumps({'type': 'info', 'msg': 'chart generated'}).decode() + '\n\n'
595598

596599
# filter chart
597600
print(full_chart_text)
598601
chart = llm_service.check_save_chart(res=full_chart_text)
599602
print(chart)
600-
if stream:
603+
if in_chat:
601604
yield orjson.dumps({'content': orjson.dumps(chart).decode(), 'type': 'chart'}).decode() + '\n\n'
605+
else:
606+
data = []
607+
_fields = {}
608+
if chart.get('columns'):
609+
for _column in chart.get('columns'):
610+
if _column:
611+
_fields[_column.get('value')] = _column.get('name')
612+
if chart.get('axis'):
613+
if chart.get('axis').get('x'):
614+
_fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name')
615+
if chart.get('axis').get('y'):
616+
_fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name')
617+
if chart.get('axis').get('series'):
618+
_fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get('name')
619+
_fields_list = []
620+
_fields_skip = False
621+
for _data in result.get('data'):
622+
_row = []
623+
for field in result.get('fields'):
624+
_row.append(_data.get(field))
625+
if not _fields_skip:
626+
_fields_list.append(field if not _fields.get(field) else _fields.get(field))
627+
data.append(_row)
628+
_fields_skip = True
629+
df = pd.DataFrame(np.array(data), columns=_fields_list)
630+
markdown_table = df.to_markdown(index=False)
631+
yield markdown_table + '\n\n'
602632

603633
record = llm_service.finish()
604-
if stream:
634+
if in_chat:
605635
yield orjson.dumps({'type': 'finish'}).decode() + '\n\n'
606636
else:
607-
md_str = f'```sql\n{sql}\n```\n\n'
608637
# todo generate picture
609-
if chart['type'] == 'table':
610-
data = {}
611-
for _data in result['data']:
612-
for field in result['fields']:
613-
if not data[field]:
614-
data[field] = []
615-
data[field].append(_data[field])
616-
df = pd.DataFrame(data, columns=result['fields'])
617-
markdown_table = df.to_markdown(index=False)
618-
md_str += markdown_table
619-
else:
620-
md_str += ''
638+
if chart['type'] != 'table':
639+
yield '# todo generate chart picture'
621640

622641
except Exception as e:
623642
llm_service.save_error(message=str(e))
624-
if stream:
643+
if in_chat:
625644
yield orjson.dumps({'content': str(e), 'type': 'error'}).decode() + '\n\n'
626645
else:
627646
raise e

backend/apps/mcp/mcp.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Annotated
66

77
from fastapi import APIRouter, Depends, HTTPException
8+
from fastapi.responses import StreamingResponse
89
from fastapi.security import OAuth2PasswordRequestForm
910

1011
from apps.chat.api.chat import create_chat
@@ -59,22 +60,4 @@ async def mcp_question(session: SessionDep, chat: ChatMcp):
5960
llm_service = LLMService(session, user, chat)
6061
llm_service.init_record()
6162

62-
run_task(llm_service, session)
63-
64-
# return await stream_sql(session, user, chat)
65-
return {"content": """这是一段写死的测试内容:
66-
67-
步骤1: 确定需要查询的字段。
68-
我们需要统计上海的订单总数,因此需要从"城市"字段中筛选出值为"上海"的记录,并使用COUNT函数计算这些记录的数量。
69-
70-
步骤2: 确定筛选条件。
71-
问题要求统计上海的订单总数,所以我们需要在SQL语句中添加WHERE "城市" = '上海'来筛选出符合条件的记录。
72-
73-
步骤3: 避免关键字冲突。
74-
因为这个Excel/CSV数据库是 PostgreSQL 类型,所以在schema、表名、字段名和别名外层加双引号。
75-
76-
最终答案:
77-
```json
78-
{"success":true,"sql":"SELECT COUNT(*) AS \"TotalOrders\" FROM \"public\".\"Sheet1_c27345b66e\" WHERE \"城市\" = '上海';"}
79-
```
80-
<img src="https://sqlbot.fit2cloud.cn/images/111.png">"""}
63+
return StreamingResponse(run_task(llm_service, session, False), media_type="text/event-stream")

backend/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535
"oracledb (>=3.1.1,<4.0.0)",
3636
"pyyaml (>=6.0.2,<7.0.0)",
3737
"fastapi-mcp (>=0.3.4,<0.4.0)",
38+
"tabulate>=0.9.0",
3839
"sqlbot-xpack==0.0.3.1",
3940
]
4041
[[tool.uv.index]]

0 commit comments

Comments
 (0)