Skip to content

Commit 60e9b7a

Browse files
committed
feat: auto select datasource
1 parent a0b9c08 commit 60e9b7a

File tree

12 files changed

+317
-51
lines changed

12 files changed

+317
-51
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""016_modify_chat
2+
3+
Revision ID: 031148da1d81
4+
Revises: 02d84523a979
5+
Create Date: 2025-06-26 17:00:07.054531
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '031148da1d81'
15+
down_revision = '02d84523a979'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.alter_column('chat', 'datasource',
23+
existing_type=sa.INTEGER(),
24+
nullable=True)
25+
op.add_column('chat_record', sa.Column('ai_modal_id', sa.Integer(), nullable=True))
26+
op.add_column('chat_record', sa.Column('first_chat', sa.Boolean(), nullable=True))
27+
op.add_column('chat_record', sa.Column('recommended_question_answer', sa.Text(), nullable=True))
28+
op.add_column('chat_record', sa.Column('recommended_question', sa.Text(), nullable=True))
29+
op.add_column('chat_record', sa.Column('datasource_select_answer', sa.Text(), nullable=True))
30+
op.add_column('chat_record', sa.Column('token_sql', sa.Integer(), nullable=True))
31+
op.add_column('chat_record', sa.Column('token_chart', sa.Integer(), nullable=True))
32+
op.add_column('chat_record', sa.Column('token_analysis', sa.Integer(), nullable=True))
33+
op.add_column('chat_record', sa.Column('token_predict', sa.Integer(), nullable=True))
34+
op.add_column('chat_record', sa.Column('full_recommended_question_message', sa.Text(), nullable=True))
35+
op.add_column('chat_record', sa.Column('token_recommended_question', sa.Integer(), nullable=True))
36+
op.add_column('chat_record', sa.Column('full_select_datasource_message', sa.Text(), nullable=True))
37+
op.add_column('chat_record', sa.Column('token_select_datasource_question', sa.Integer(), nullable=True))
38+
op.alter_column('chat_record', 'chat_id',
39+
existing_type=sa.INTEGER(),
40+
nullable=False)
41+
op.alter_column('chat_record', 'datasource',
42+
existing_type=sa.INTEGER(),
43+
nullable=True)
44+
# ### end Alembic commands ###
45+
46+
47+
def downgrade():
48+
# ### commands auto generated by Alembic - please adjust! ###
49+
op.alter_column('chat_record', 'datasource',
50+
existing_type=sa.INTEGER(),
51+
nullable=False)
52+
op.alter_column('chat_record', 'chat_id',
53+
existing_type=sa.INTEGER(),
54+
nullable=True)
55+
op.drop_column('chat_record', 'token_select_datasource_question')
56+
op.drop_column('chat_record', 'full_select_datasource_message')
57+
op.drop_column('chat_record', 'token_recommended_question')
58+
op.drop_column('chat_record', 'full_recommended_question_message')
59+
op.drop_column('chat_record', 'token_predict')
60+
op.drop_column('chat_record', 'token_analysis')
61+
op.drop_column('chat_record', 'token_chart')
62+
op.drop_column('chat_record', 'token_sql')
63+
op.drop_column('chat_record', 'datasource_select_answer')
64+
op.drop_column('chat_record', 'recommended_question')
65+
op.drop_column('chat_record', 'recommended_question_answer')
66+
op.drop_column('chat_record', 'first_chat')
67+
op.drop_column('chat_record', 'ai_modal_id')
68+
op.alter_column('chat', 'datasource',
69+
existing_type=sa.INTEGER(),
70+
nullable=False)
71+
# ### end Alembic commands ###

backend/apps/chat/api/chat.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
7272
async def mcp_question(session: SessionDep, token: str, request_question: ChatQuestion):
7373
user = await get_current_user(session, token)
7474
# return await stream_sql(session, user, request_question)
75-
return {"content":"""步骤1: 确定需要查询的字段。
75+
return {"content": """步骤1: 确定需要查询的字段。
7676
我们需要统计上海的订单总数,因此需要从"城市"字段中筛选出值为"上海"的记录,并使用COUNT函数计算这些记录的数量。
7777
7878
步骤2: 确定筛选条件。
@@ -106,16 +106,17 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
106106
status_code=400,
107107
detail=f"Chat with id {request_question.chat_id} not found"
108108
)
109-
110-
# Get available datasource
111-
ds = session.query(CoreDatasource).filter(CoreDatasource.id == chat.datasource).first()
112-
if not ds:
113-
raise HTTPException(
114-
status_code=500,
115-
detail="No available datasource configuration found"
116-
)
117-
118-
request_question.engine = ds.type_name if ds.type != 'excel' else 'PostgreSQL'
109+
ds: CoreDatasource | None = None
110+
if chat.datasource:
111+
# Get available datasource
112+
ds = session.query(CoreDatasource).filter(CoreDatasource.id == chat.datasource).first()
113+
if not ds:
114+
raise HTTPException(
115+
status_code=500,
116+
detail="No available datasource configuration found"
117+
)
118+
119+
request_question.engine = ds.type_name if ds.type != 'excel' else 'PostgreSQL'
119120

120121
# Get available AI model
121122
aimodel = session.exec(select(AiModelDetail).where(
@@ -128,14 +129,18 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
128129
detail="No available AI model configuration found"
129130
)
130131

131-
history_records: List[ChatRecord] = list_records(session=session, current_user=current_user,
132-
chart_id=request_question.chat_id)
132+
history_records: List[ChatRecord] = list(filter(lambda r: True if r.first_chat != True else False,
133+
list_records(session=session, current_user=current_user,
134+
chart_id=request_question.chat_id)))
133135
# get schema
134-
request_question.db_schema = get_table_schema(session=session, ds=ds)
136+
if ds:
137+
request_question.db_schema = get_table_schema(session=session, ds=ds)
138+
135139
db_user = get_user_info(session=session, user_id=current_user.id)
136140
request_question.lang = db_user.language
137141

138-
llm_service = LLMService(request_question, aimodel, history_records, CoreDatasource(**ds.model_dump()))
142+
llm_service = LLMService(request_question, aimodel, history_records,
143+
CoreDatasource(**ds.model_dump()) if ds else None)
139144

140145
llm_service.init_record(session=session, current_user=current_user)
141146

@@ -144,6 +149,16 @@ def run_task():
144149
# return id
145150
yield orjson.dumps({'type': 'id', 'id': llm_service.get_record().id}).decode() + '\n\n'
146151

152+
# select datasource if datasource is none
153+
if not ds:
154+
ds_res = llm_service.select_datasource(session=session)
155+
for chunk in ds_res:
156+
yield orjson.dumps({'content': chunk, 'type': 'datasource-result'}).decode() + '\n\n'
157+
yield orjson.dumps({'id': llm_service.ds.id, 'datasource_name': llm_service.ds.name,
158+
'engine_type': llm_service.ds.type_name, 'type': 'datasource'}).decode() + '\n\n'
159+
160+
llm_service.chat_question.db_schema = get_table_schema(session=session, ds=llm_service.ds)
161+
147162
# generate sql
148163
sql_res = llm_service.generate_sql(session=session)
149164
full_sql_text = ''

backend/apps/chat/curd/chat.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
6060
load_only(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time,
6161
ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql, ChatRecord.data,
6262
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict,
63+
ChatRecord.datasource_select_answer, ChatRecord.recommended_question_answer,
64+
ChatRecord.recommended_question,
6365
ChatRecord.predict_data, ChatRecord.finish, ChatRecord.error, ChatRecord.run_time)).filter(
6466
and_(Chat.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by(ChatRecord.create_time).all()
6567

@@ -74,7 +76,8 @@ def list_records(session: SessionDep, chart_id: int, current_user: CurrentUser)
7476
return record_list
7577

7678

77-
def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat, require_datasource: bool = True) -> ChatInfo:
79+
def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat,
80+
require_datasource: bool = True) -> ChatInfo:
7881
if not create_chat_obj.datasource and require_datasource:
7982
raise Exception("Datasource cannot be None")
8083

@@ -84,7 +87,7 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj:
8487
chat = Chat(create_time=datetime.datetime.now(),
8588
create_by=current_user.id,
8689
brief=create_chat_obj.question.strip()[:20])
87-
ds: CoreDatasource = None
90+
ds: CoreDatasource | None = None
8891
if create_chat_obj.datasource:
8992
chat.datasource = create_chat_obj.datasource
9093
ds = session.query(CoreDatasource).filter(CoreDatasource.id == create_chat_obj.datasource).first()
@@ -102,15 +105,9 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj:
102105
chat_info.id = chat.id
103106
session.commit()
104107

105-
if not create_chat_obj.datasource:
106-
# use AI to get ds
107-
108-
if not ds:
109-
raise Exception(f"Datasource with id {create_chat_obj.datasource} not found")
110-
111-
112-
chat_info.datasource_exists = True
113-
chat_info.datasource_name = ds.name
108+
if ds:
109+
chat_info.datasource_exists = True
110+
chat_info.datasource_name = ds.name
114111

115112
return chat_info
116113

@@ -205,6 +202,30 @@ def save_full_predict_message_and_answer(session: SessionDep, record_id: int, an
205202
return result
206203

207204

205+
def save_full_select_datasource_message_and_answer(session: SessionDep, record_id: int, answer: str,
206+
full_message: str, datasource: int = None,
207+
engine_type: str = None) -> ChatRecord:
208+
if not record_id:
209+
raise Exception("Record id cannot be None")
210+
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
211+
record.full_select_datasource_message = full_message
212+
record.datasource_select_answer = answer
213+
214+
if datasource:
215+
record.datasource = datasource
216+
record.engine_type = engine_type
217+
218+
result = ChatRecord(**record.model_dump())
219+
220+
session.add(record)
221+
session.flush()
222+
session.refresh(record)
223+
224+
session.commit()
225+
226+
return result
227+
228+
208229
def save_sql(session: SessionDep, record_id: int, sql: str) -> ChatRecord:
209230
if not record_id:
210231
raise Exception("Record id cannot be None")

backend/apps/chat/models/chat_model.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
from apps.template.generate_analysis.generator import get_analysis_template
99
from apps.template.generate_chart.generator import get_chart_template
10+
from apps.template.generate_guess_question.generator import get_guess_question_template
1011
from apps.template.generate_predict.generator import get_predict_template
1112
from apps.template.generate_sql.generator import get_sql_template
13+
from apps.template.select_datasource.generator import get_datasource_template
1214

1315

1416
class Chat(SQLModel, table=True):
@@ -25,11 +27,13 @@ class Chat(SQLModel, table=True):
2527
class ChatRecord(SQLModel, table=True):
2628
__tablename__ = "chat_record"
2729
id: Optional[int] = Field(sa_column=Column(Integer, Identity(always=True), primary_key=True))
28-
chat_id: int = Field(sa_column=Column(Integer))
30+
chat_id: int = Field(sa_column=Column(Integer, nullable=False))
31+
ai_modal_id: Optional[int] = Field(sa_column=Column(Integer))
32+
first_chat: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
2933
create_time: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
3034
finish_time: datetime = Field(sa_column=Column(DateTime(timezone=True), nullable=True))
3135
create_by: int = Field(sa_column=Column(BigInteger, nullable=True))
32-
datasource: int = Field(sa_column=Column(Integer, nullable=False))
36+
datasource: int = Field(sa_column=Column(Integer, nullable=True))
3337
engine_type: str = Field(max_length=64)
3438
question: str = Field(sa_column=Column(Text, nullable=True))
3539
sql_answer: str = Field(sa_column=Column(Text, nullable=True))
@@ -41,10 +45,21 @@ class ChatRecord(SQLModel, table=True):
4145
analysis: str = Field(sa_column=Column(Text, nullable=True))
4246
predict: str = Field(sa_column=Column(Text, nullable=True))
4347
predict_data: str = Field(sa_column=Column(Text, nullable=True))
48+
recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True))
49+
recommended_question: str = Field(sa_column=Column(Text, nullable=True))
50+
datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True))
4451
full_sql_message: str = Field(sa_column=Column(Text, nullable=True))
52+
token_sql: int = Field(default=0, nullable=True)
4553
full_chart_message: str = Field(sa_column=Column(Text, nullable=True))
54+
token_chart: int = Field(default=0, nullable=True)
4655
full_analysis_message: str = Field(sa_column=Column(Text, nullable=True))
56+
token_analysis: int = Field(default=0, nullable=True)
4757
full_predict_message: str = Field(sa_column=Column(Text, nullable=True))
58+
token_predict: int = Field(default=0, nullable=True)
59+
full_recommended_question_message: str = Field(sa_column=Column(Text, nullable=True))
60+
token_recommended_question: int = Field(default=0, nullable=True)
61+
full_select_datasource_message: str = Field(sa_column=Column(Text, nullable=True))
62+
token_select_datasource_question: int = Field(default=0, nullable=True)
4863
finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
4964
error: str = Field(sa_column=Column(Text, nullable=True))
5065
run_time: float = Field(default=0)
@@ -67,7 +82,7 @@ class ChatInfo(BaseModel):
6782
create_by: int = None
6883
brief: str = ''
6984
chat_type: str = "chat"
70-
datasource: int = None
85+
datasource: Optional[int] = None
7186
engine_type: str = ''
7287
datasource_name: str = ''
7388
datasource_exists: bool = True
@@ -108,6 +123,19 @@ def predict_sys_question(self):
108123
def predict_user_question(self):
109124
return get_predict_template()['user'].format(fields=self.fields, data=self.data, lang=self.lang)
110125

126+
def datasource_sys_question(self):
127+
return get_datasource_template()['system']
128+
129+
def datasource_user_question(self, datasource_list: str = "[]"):
130+
return get_datasource_template()['user'].format(question=self.question, data=datasource_list, lang=self.lang)
131+
132+
def datasource_guess_sys_question(self):
133+
return get_guess_question_template()['system']
134+
135+
def datasource_guess_user_question(self, old_questions: str = "[]"):
136+
return get_guess_question_template()['user'].format(question=self.question, schema=self.db_schema,
137+
old_questions=old_questions, lang=self.lang)
138+
111139

112140
class ChatQuestion(AiModelQuestion):
113141
question: str

0 commit comments

Comments
 (0)