Skip to content

Commit a1bf97a

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents 151daf6 + 1d5e777 commit a1bf97a

File tree

18 files changed

+444
-95
lines changed

18 files changed

+444
-95
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""016_modify_chat
2+
3+
Revision ID: 031148da1d81
4+
Revises: 02d84523a979
5+
Create Date: 2025-06-26 17:00:07.054531
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '031148da1d81'
15+
down_revision = '02d84523a979'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.alter_column('chat', 'datasource',
23+
existing_type=sa.INTEGER(),
24+
nullable=True)
25+
op.add_column('chat_record', sa.Column('ai_modal_id', sa.Integer(), nullable=True))
26+
op.add_column('chat_record', sa.Column('first_chat', sa.Boolean(), nullable=True))
27+
op.add_column('chat_record', sa.Column('recommended_question_answer', sa.Text(), nullable=True))
28+
op.add_column('chat_record', sa.Column('recommended_question', sa.Text(), nullable=True))
29+
op.add_column('chat_record', sa.Column('datasource_select_answer', sa.Text(), nullable=True))
30+
op.add_column('chat_record', sa.Column('token_sql', sa.Integer(), nullable=True))
31+
op.add_column('chat_record', sa.Column('token_chart', sa.Integer(), nullable=True))
32+
op.add_column('chat_record', sa.Column('token_analysis', sa.Integer(), nullable=True))
33+
op.add_column('chat_record', sa.Column('token_predict', sa.Integer(), nullable=True))
34+
op.add_column('chat_record', sa.Column('full_recommended_question_message', sa.Text(), nullable=True))
35+
op.add_column('chat_record', sa.Column('token_recommended_question', sa.Integer(), nullable=True))
36+
op.add_column('chat_record', sa.Column('full_select_datasource_message', sa.Text(), nullable=True))
37+
op.add_column('chat_record', sa.Column('token_select_datasource_question', sa.Integer(), nullable=True))
38+
op.alter_column('chat_record', 'chat_id',
39+
existing_type=sa.INTEGER(),
40+
nullable=False)
41+
op.alter_column('chat_record', 'datasource',
42+
existing_type=sa.INTEGER(),
43+
nullable=True)
44+
# ### end Alembic commands ###
45+
46+
47+
def downgrade():
48+
# ### commands auto generated by Alembic - please adjust! ###
49+
op.alter_column('chat_record', 'datasource',
50+
existing_type=sa.INTEGER(),
51+
nullable=False)
52+
op.alter_column('chat_record', 'chat_id',
53+
existing_type=sa.INTEGER(),
54+
nullable=True)
55+
op.drop_column('chat_record', 'token_select_datasource_question')
56+
op.drop_column('chat_record', 'full_select_datasource_message')
57+
op.drop_column('chat_record', 'token_recommended_question')
58+
op.drop_column('chat_record', 'full_recommended_question_message')
59+
op.drop_column('chat_record', 'token_predict')
60+
op.drop_column('chat_record', 'token_analysis')
61+
op.drop_column('chat_record', 'token_chart')
62+
op.drop_column('chat_record', 'token_sql')
63+
op.drop_column('chat_record', 'datasource_select_answer')
64+
op.drop_column('chat_record', 'recommended_question')
65+
op.drop_column('chat_record', 'recommended_question_answer')
66+
op.drop_column('chat_record', 'first_chat')
67+
op.drop_column('chat_record', 'ai_modal_id')
68+
op.alter_column('chat', 'datasource',
69+
existing_type=sa.INTEGER(),
70+
nullable=False)
71+
# ### end Alembic commands ###
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""017_rsa_ddl
2+
3+
Revision ID: a0ba8268868d
4+
Revises: 031148da1d81
5+
Create Date: 2025-06-27 15:05:38.676825
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
12+
13+
# revision identifiers, used by Alembic.
14+
revision = 'a0ba8268868d'
15+
down_revision = '031148da1d81'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
op.create_table(
22+
'rsa',
23+
sa.Column('id', sa.BigInteger(), primary_key=True, nullable=False),
24+
sa.Column('private_key', sa.Text(), default="", nullable=False),
25+
sa.Column('public_key', sa.Text(), default="", nullable=False),
26+
sa.Column('salt', sa.Text(), default="", nullable=False),
27+
sa.Column('create_time', sa.BigInteger(), default=0, nullable=False),
28+
sa.Column('update_time', sa.BigInteger(), default=0, nullable=False)
29+
)
30+
op.create_index(op.f('ix_rsa_id'), 'rsa', ['id'], unique=False)
31+
32+
33+
def downgrade():
34+
op.drop_table('rsa')
35+
op.drop_index(op.f('ix_rsa_id'), table_name='rsa')

backend/apps/chat/api/chat.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
1010
delete_chat, list_records
11-
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, Chat, ChatQuestion
11+
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, Chat, ChatQuestion, ChatMcp
1212
from apps.chat.task.llm import LLMService
1313
from apps.datasource.crud.datasource import get_table_schema
1414
from apps.datasource.models.datasource import CoreDatasource
@@ -57,6 +57,18 @@ async def delete(session: SessionDep, chart_id: int):
5757
)
5858

5959

60+
@router.post("/mcp_start", operation_id="mcp_start")
61+
async def mcp_start(session: SessionDep, chat: ChatMcp):
62+
user = await get_current_user(session, chat.token)
63+
return create_chat(session, user, CreateChat(), False)
64+
65+
66+
@router.post("/mcp_question", operation_id="mcp_question")
67+
async def mcp_question(session: SessionDep, chat: ChatMcp):
68+
user = await get_current_user(session, chat.token)
69+
return await stream_sql(session, user, chat)
70+
71+
6072
@router.post("/start")
6173
async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat):
6274
try:
@@ -68,25 +80,6 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
6880
)
6981

7082

71-
@router.post("/mcp_question", operation_id="mcp_question")
72-
async def mcp_question(session: SessionDep, token: str, request_question: ChatQuestion):
73-
user = await get_current_user(session, token)
74-
# return await stream_sql(session, user, request_question)
75-
return {"content":"""步骤1: 确定需要查询的字段。
76-
我们需要统计上海的订单总数,因此需要从"城市"字段中筛选出值为"上海"的记录,并使用COUNT函数计算这些记录的数量。
77-
78-
步骤2: 确定筛选条件。
79-
问题要求统计上海的订单总数,所以我们需要在SQL语句中添加WHERE "城市" = '上海'来筛选出符合条件的记录。
80-
81-
步骤3: 避免关键字冲突。
82-
因为这个Excel/CSV数据库是 PostgreSQL 类型,所以在schema、表名、字段名和别名外层加双引号。
83-
84-
最终答案:
85-
```json
86-
{"success":true,"sql":"SELECT COUNT(*) AS \"TotalOrders\" FROM \"public\".\"Sheet1_c27345b66e\" WHERE \"城市\" = '上海';"}
87-
```"""}
88-
89-
9083
@router.post("/question", operation_id="question")
9184
async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion):
9285
"""Stream SQL analysis results
@@ -106,16 +99,17 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
10699
status_code=400,
107100
detail=f"Chat with id {request_question.chat_id} not found"
108101
)
109-
110-
# Get available datasource
111-
ds = session.query(CoreDatasource).filter(CoreDatasource.id == chat.datasource).first()
112-
if not ds:
113-
raise HTTPException(
114-
status_code=500,
115-
detail="No available datasource configuration found"
116-
)
117-
118-
request_question.engine = ds.type_name if ds.type != 'excel' else 'PostgreSQL'
102+
ds: CoreDatasource | None = None
103+
if chat.datasource:
104+
# Get available datasource
105+
ds = session.query(CoreDatasource).filter(CoreDatasource.id == chat.datasource).first()
106+
if not ds:
107+
raise HTTPException(
108+
status_code=500,
109+
detail="No available datasource configuration found"
110+
)
111+
112+
request_question.engine = ds.type_name if ds.type != 'excel' else 'PostgreSQL'
119113

120114
# Get available AI model
121115
aimodel = session.exec(select(AiModelDetail).where(
@@ -128,14 +122,18 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
128122
detail="No available AI model configuration found"
129123
)
130124

131-
history_records: List[ChatRecord] = list_records(session=session, current_user=current_user,
132-
chart_id=request_question.chat_id)
125+
history_records: List[ChatRecord] = list(filter(lambda r: True if r.first_chat != True else False,
126+
list_records(session=session, current_user=current_user,
127+
chart_id=request_question.chat_id)))
133128
# get schema
134-
request_question.db_schema = get_table_schema(session=session, ds=ds)
129+
if ds:
130+
request_question.db_schema = get_table_schema(session=session, ds=ds)
131+
135132
db_user = get_user_info(session=session, user_id=current_user.id)
136133
request_question.lang = db_user.language
137134

138-
llm_service = LLMService(request_question, aimodel, history_records, CoreDatasource(**ds.model_dump()))
135+
llm_service = LLMService(request_question, aimodel, history_records,
136+
CoreDatasource(**ds.model_dump()) if ds else None)
139137

140138
llm_service.init_record(session=session, current_user=current_user)
141139

@@ -144,6 +142,16 @@ def run_task():
144142
# return id
145143
yield orjson.dumps({'type': 'id', 'id': llm_service.get_record().id}).decode() + '\n\n'
146144

145+
# select datasource if datasource is none
146+
if not ds:
147+
ds_res = llm_service.select_datasource(session=session)
148+
for chunk in ds_res:
149+
yield orjson.dumps({'content': chunk, 'type': 'datasource-result'}).decode() + '\n\n'
150+
yield orjson.dumps({'id': llm_service.ds.id, 'datasource_name': llm_service.ds.name,
151+
'engine_type': llm_service.ds.type_name, 'type': 'datasource'}).decode() + '\n\n'
152+
153+
llm_service.chat_question.db_schema = get_table_schema(session=session, ds=llm_service.ds)
154+
147155
# generate sql
148156
sql_res = llm_service.generate_sql(session=session)
149157
full_sql_text = ''

backend/apps/chat/curd/chat.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
6060
load_only(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time,
6161
ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql, ChatRecord.data,
6262
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict,
63+
ChatRecord.datasource_select_answer, ChatRecord.recommended_question_answer,
64+
ChatRecord.recommended_question,
6365
ChatRecord.predict_data, ChatRecord.finish, ChatRecord.error, ChatRecord.run_time)).filter(
6466
and_(Chat.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by(ChatRecord.create_time).all()
6567

@@ -74,7 +76,8 @@ def list_records(session: SessionDep, chart_id: int, current_user: CurrentUser)
7476
return record_list
7577

7678

77-
def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat, require_datasource: bool = True) -> ChatInfo:
79+
def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat,
80+
require_datasource: bool = True) -> ChatInfo:
7881
if not create_chat_obj.datasource and require_datasource:
7982
raise Exception("Datasource cannot be None")
8083

@@ -84,7 +87,7 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj:
8487
chat = Chat(create_time=datetime.datetime.now(),
8588
create_by=current_user.id,
8689
brief=create_chat_obj.question.strip()[:20])
87-
ds: CoreDatasource = None
90+
ds: CoreDatasource | None = None
8891
if create_chat_obj.datasource:
8992
chat.datasource = create_chat_obj.datasource
9093
ds = session.query(CoreDatasource).filter(CoreDatasource.id == create_chat_obj.datasource).first()
@@ -93,6 +96,8 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj:
9396
raise Exception(f"Datasource with id {create_chat_obj.datasource} not found")
9497

9598
chat.engine_type = ds.type_name
99+
else:
100+
chat.engine_type = ''
96101

97102
chat_info = ChatInfo(**chat.model_dump())
98103

@@ -102,15 +107,9 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj:
102107
chat_info.id = chat.id
103108
session.commit()
104109

105-
if not create_chat_obj.datasource:
106-
# use AI to get ds
107-
108-
if not ds:
109-
raise Exception(f"Datasource with id {create_chat_obj.datasource} not found")
110-
111-
112-
chat_info.datasource_exists = True
113-
chat_info.datasource_name = ds.name
110+
if ds:
111+
chat_info.datasource_exists = True
112+
chat_info.datasource_name = ds.name
114113

115114
return chat_info
116115

@@ -205,6 +204,30 @@ def save_full_predict_message_and_answer(session: SessionDep, record_id: int, an
205204
return result
206205

207206

207+
def save_full_select_datasource_message_and_answer(session: SessionDep, record_id: int, answer: str,
208+
full_message: str, datasource: int = None,
209+
engine_type: str = None) -> ChatRecord:
210+
if not record_id:
211+
raise Exception("Record id cannot be None")
212+
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
213+
record.full_select_datasource_message = full_message
214+
record.datasource_select_answer = answer
215+
216+
if datasource:
217+
record.datasource = datasource
218+
record.engine_type = engine_type
219+
220+
result = ChatRecord(**record.model_dump())
221+
222+
session.add(record)
223+
session.flush()
224+
session.refresh(record)
225+
226+
session.commit()
227+
228+
return result
229+
230+
208231
def save_sql(session: SessionDep, record_id: int, sql: str) -> ChatRecord:
209232
if not record_id:
210233
raise Exception("Record id cannot be None")

0 commit comments

Comments
 (0)