Skip to content

Commit 0aa42bc

Browse files
committed
feat: improve regenerate SQL chat feature
1 parent 52bba50 commit 0aa42bc

File tree

15 files changed

+334
-35
lines changed

15 files changed

+334
-35
lines changed

backend/alembic/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
# from apps.system.models.user import SQLModel # noqa
2626
# from apps.settings.models.setting_models import SQLModel
27-
# from apps.chat.models.chat_model import SQLModel
27+
from apps.chat.models.chat_model import SQLModel
2828
from apps.terminology.models.terminology_model import SQLModel
2929
#from apps.custom_prompt.models.custom_prompt_model import SQLModel
3030
from apps.data_training.models.data_training_model import SQLModel
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""054_update_chat_record_dll
2+
3+
Revision ID: 24e961f6326b
4+
Revises: 5755c0b95839
5+
Create Date: 2025-12-04 15:51:42.900778
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '24e961f6326b'
15+
down_revision = '5755c0b95839'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.add_column('chat_record', sa.Column('regenerate_record_id', sa.BigInteger(), nullable=True))
23+
# ### end Alembic commands ###
24+
25+
26+
def downgrade():
27+
# ### commands auto generated by Alembic - please adjust! ###
28+
op.drop_column('chat_record', 'regenerate_record_id')
29+
# ### end Alembic commands ###

backend/apps/chat/api/chat.py

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
1313
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id, \
1414
format_json_data, format_json_list_data, get_chart_config, list_recent_questions
15-
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj
15+
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj, QuickCommand
1616
from apps.chat.task.llm import LLMService
1717
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans
18+
from common.utils.command_utils import parse_quick_command
1819
from common.utils.data_format import DataFormat
1920

2021
router = APIRouter(tags=["Data Q&A"], prefix="/chat")
@@ -141,20 +142,99 @@ async def recommend_questions(session: SessionDep, current_user: CurrentUser, da
141142
return list_recent_questions(session=session, current_user=current_user, datasource_id=datasource_id)
142143

143144

145+
def find_base_question(record_id: int, session: SessionDep):
146+
stmt = select(ChatRecord.question, ChatRecord.regenerate_record_id).where(
147+
and_(ChatRecord.id == record_id))
148+
_record = session.execute(stmt).fetchone()
149+
if not _record:
150+
raise Exception(f'Cannot find base chat record')
151+
rec_question, rec_regenerate_record_id = _record
152+
if rec_regenerate_record_id:
153+
return find_base_question(rec_regenerate_record_id, session)
154+
else:
155+
return rec_question
156+
157+
144158
@router.post("/question")
159+
async def question_answer(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
160+
current_assistant: CurrentAssistant):
161+
try:
162+
command, text_before_command, record_id, warning_info = parse_quick_command(request_question.question)
163+
if command:
164+
# todo 暂不支持分析和预测,需要改造前端
165+
if command == QuickCommand.ANALYSIS or command == QuickCommand.PREDICT_DATA:
166+
raise Exception(f'Command: {command.value} temporary not supported')
167+
168+
if record_id is not None:
169+
# 排除analysis和predict
170+
stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.analysis_record_id,
171+
ChatRecord.predict_record_id, ChatRecord.regenerate_record_id,
172+
ChatRecord.first_chat).where(
173+
and_(ChatRecord.id == record_id))
174+
_record = session.execute(stmt).fetchone()
175+
if not _record:
176+
raise Exception(f'Record id: {record_id} does not exist')
177+
178+
rec_id, rec_chat_id, rec_analysis_record_id, rec_predict_record_id, rec_regenerate_record_id, rec_first_chat = _record
179+
180+
if rec_chat_id != request_question.chat_id:
181+
raise Exception(f'Record id: {record_id} does not belong to this chat')
182+
if rec_first_chat:
183+
raise Exception(f'Record id: {record_id} does not support this operation')
184+
185+
if command == QuickCommand.REGENERATE:
186+
if rec_analysis_record_id:
187+
raise Exception('Analysis record does not support this operation')
188+
if rec_predict_record_id:
189+
raise Exception('Predict data record does not support this operation')
190+
191+
else: # get last record id
192+
stmt = select(ChatRecord.id, ChatRecord.chat_id, ChatRecord.regenerate_record_id).where(
193+
and_(ChatRecord.chat_id == request_question.chat_id,
194+
ChatRecord.first_chat == False,
195+
ChatRecord.analysis_record_id.is_(None),
196+
ChatRecord.predict_record_id.is_(None))).order_by(
197+
ChatRecord.create_time.desc()).limit(1)
198+
_record = session.execute(stmt).fetchone()
199+
200+
if not _record:
201+
raise Exception(f'You have not ask any question')
202+
203+
rec_id, rec_chat_id, rec_regenerate_record_id = _record
204+
205+
# 没有指定的,就查询上一个
206+
if not rec_regenerate_record_id:
207+
rec_regenerate_record_id = rec_id
208+
209+
# 针对已经是重新生成的提问,需要找到原来的提问是什么
210+
base_question_text = find_base_question(rec_regenerate_record_id, session)
211+
text_before_command = text_before_command + ("\n" if text_before_command else "") + base_question_text
212+
213+
if command == QuickCommand.REGENERATE:
214+
request_question.question = text_before_command
215+
request_question.regenerate_record_id = rec_id
216+
return await stream_sql(session, current_user, request_question, current_assistant)
217+
218+
elif command == QuickCommand.ANALYSIS:
219+
return await analysis_or_predict(session, current_user, rec_id, 'analysis', current_assistant)
220+
221+
elif command == QuickCommand.PREDICT_DATA:
222+
return await analysis_or_predict(session, current_user, rec_id, 'predict', current_assistant)
223+
else:
224+
raise Exception(f'Unknown command: {command.value}')
225+
else:
226+
return await stream_sql(session, current_user, request_question, current_assistant)
227+
except Exception as e:
228+
traceback.print_exc()
229+
230+
def _err(_e: Exception):
231+
yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'
232+
233+
return StreamingResponse(_err(e), media_type="text/event-stream")
234+
235+
145236
async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
146237
current_assistant: CurrentAssistant):
147-
"""Stream SQL analysis results
148-
149-
Args:
150-
session: Database session
151-
current_user: CurrentUser
152-
request_question: User question model
153-
154-
Returns:
155-
Streaming response with analysis results
156-
"""
157-
158238
try:
159239
llm_service = await LLMService.create(session, current_user, request_question, current_assistant,
160240
embedding=True)
@@ -172,6 +252,12 @@ def _err(_e: Exception):
172252

173253

174254
@router.post("/record/{chat_record_id}/{action_type}")
255+
async def analysis_or_predict_question(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
256+
action_type: str,
257+
current_assistant: CurrentAssistant):
258+
return await analysis_or_predict(session, current_user, chat_record_id, action_type, current_assistant)
259+
260+
175261
async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str,
176262
current_assistant: CurrentAssistant):
177263
try:

backend/apps/chat/curd/chat.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import datetime
22
from typing import List
3-
from sqlalchemy import desc, func
43

54
import orjson
65
import sqlparse
76
from sqlalchemy import and_, select, update
7+
from sqlalchemy import desc, func
88
from sqlalchemy.orm import aliased
99

1010
from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion, ChatLog, \
1111
TypeEnum, OperationEnum, ChatRecordResult
12-
from apps.datasource.crud.recommended_problem import get_datasource_recommended, get_datasource_recommended_chart
13-
from apps.datasource.models.datasource import CoreDatasource, DsRecommendedProblem
12+
from apps.datasource.crud.recommended_problem import get_datasource_recommended_chart
13+
from apps.datasource.models.datasource import CoreDatasource
1414
from apps.system.crud.assistant import AssistantOutDsFactory
1515
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans
1616
from common.utils.utils import extract_nested_json
@@ -28,11 +28,13 @@ def get_chat_record_by_id(session: SessionDep, record_id: int):
2828
engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by)
2929
return record
3030

31+
3132
def get_chat(session: SessionDep, chat_id: int) -> Chat:
3233
statement = select(Chat).where(Chat.id == chat_id)
3334
chat = session.exec(statement).scalars().first()
3435
return chat
3536

37+
3638
def list_chats(session: SessionDep, current_user: CurrentUser) -> List[Chat]:
3739
oid = current_user.oid if current_user.oid is not None else 1
3840
chart_list = session.query(Chat).filter(and_(Chat.create_by == current_user.id, Chat.oid == oid)).order_by(
@@ -57,6 +59,7 @@ def list_recent_questions(session: SessionDep, current_user: CurrentUser, dataso
5759
)
5860
return [record[0] for record in chat_records] if chat_records else []
5961

62+
6063
def rename_chat(session: SessionDep, rename_object: RenameChat) -> str:
6164
chat = session.get(Chat, rename_object.id)
6265
if not chat:
@@ -191,7 +194,8 @@ def get_chat_with_records_with_data(session: SessionDep, chart_id: int, current_
191194

192195

193196
def get_chat_with_records(session: SessionDep, chart_id: int, current_user: CurrentUser,
194-
current_assistant: CurrentAssistant, with_data: bool = False,trans: Trans = None) -> ChatInfo:
197+
current_assistant: CurrentAssistant, with_data: bool = False,
198+
trans: Trans = None) -> ChatInfo:
195199
chat = session.get(Chat, chart_id)
196200
if not chat:
197201
raise Exception(f"Chat with id {chart_id} not found")
@@ -200,7 +204,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
200204

201205
if current_assistant and current_assistant.type in dynamic_ds_types:
202206
out_ds_instance = AssistantOutDsFactory.get_instance(current_assistant)
203-
ds = out_ds_instance.get_ds(chat.datasource,trans)
207+
ds = out_ds_instance.get_ds(chat.datasource, trans)
204208
else:
205209
ds = session.get(CoreDatasource, chat.datasource) if chat.datasource else None
206210

@@ -221,6 +225,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
221225
ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql,
222226
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict,
223227
ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id,
228+
ChatRecord.regenerate_record_id,
224229
ChatRecord.recommended_question, ChatRecord.first_chat,
225230
ChatRecord.finish, ChatRecord.error,
226231
sql_alias_log.reasoning_content.label('sql_reasoning_content'),
@@ -247,6 +252,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
247252
ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql,
248253
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict,
249254
ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id,
255+
ChatRecord.regenerate_record_id,
250256
ChatRecord.recommended_question, ChatRecord.first_chat,
251257
ChatRecord.finish, ChatRecord.error, ChatRecord.data, ChatRecord.predict_data).where(
252258
and_(ChatRecord.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by(
@@ -264,6 +270,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
264270
analysis=row.analysis, predict=row.predict,
265271
datasource_select_answer=row.datasource_select_answer,
266272
analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id,
273+
regenerate_record_id=row.regenerate_record_id,
267274
recommended_question=row.recommended_question, first_chat=row.first_chat,
268275
finish=row.finish, error=row.error,
269276
sql_reasoning_content=row.sql_reasoning_content,
@@ -280,6 +287,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
280287
analysis=row.analysis, predict=row.predict,
281288
datasource_select_answer=row.datasource_select_answer,
282289
analysis_record_id=row.analysis_record_id, predict_record_id=row.predict_record_id,
290+
regenerate_record_id=row.regenerate_record_id,
283291
recommended_question=row.recommended_question, first_chat=row.first_chat,
284292
finish=row.finish, error=row.error, data=row.data, predict_data=row.predict_data))
285293

@@ -347,8 +355,9 @@ def format_record(record: ChatRecordResult):
347355

348356
return _dict
349357

358+
350359
def get_chat_brief_generate(session: SessionDep, chat_id: int):
351-
chat = get_chat(session=session,chat_id=chat_id)
360+
chat = get_chat(session=session, chat_id=chat_id)
352361
if chat is not None and chat.brief_generate is not None:
353362
return chat.brief_generate
354363
else:
@@ -468,6 +477,7 @@ def save_question(session: SessionDep, current_user: CurrentUser, question: Chat
468477
record.datasource = chat.datasource
469478
record.engine_type = chat.engine_type
470479
record.ai_modal_id = question.ai_modal_id
480+
record.regenerate_record_id = question.regenerate_record_id
471481

472482
result = ChatRecord(**record.model_dump())
473483

backend/apps/chat/models/chat_model.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ class ChatFinishStep(Enum):
4848
GENERATE_CHART = 3
4949

5050

51+
class QuickCommand(Enum):
52+
REGENERATE = '/regenerate'
53+
ANALYSIS = '/analysis'
54+
PREDICT_DATA = '/predict'
55+
56+
5157
# TODO choose table / check connection / generate description
5258

5359
class ChatLog(SQLModel, table=True):
@@ -78,7 +84,7 @@ class Chat(SQLModel, table=True):
7884
datasource: int = Field(sa_column=Column(BigInteger, nullable=True))
7985
engine_type: str = Field(max_length=64)
8086
origin: Optional[int] = Field(
81-
sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant
87+
sa_column=Column(Integer, nullable=False, default=0)) # 0: default, 1: mcp, 2: assistant
8288
brief_generate: bool = Field(default=False)
8389

8490

@@ -110,6 +116,7 @@ class ChatRecord(SQLModel, table=True):
110116
error: str = Field(sa_column=Column(Text, nullable=True))
111117
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
112118
predict_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
119+
regenerate_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
113120

114121

115122
class ChatRecordResult(BaseModel):
@@ -134,6 +141,7 @@ class ChatRecordResult(BaseModel):
134141
error: Optional[str] = None
135142
analysis_record_id: Optional[int] = None
136143
predict_record_id: Optional[int] = None
144+
regenerate_record_id: Optional[int] = None
137145
sql_reasoning_content: Optional[str] = None
138146
chart_reasoning_content: Optional[str] = None
139147
analysis_reasoning_content: Optional[str] = None
@@ -184,6 +192,7 @@ class AiModelQuestion(BaseModel):
184192
data_training: str = ""
185193
custom_prompt: str = ""
186194
error_msg: str = ""
195+
regenerate_record_id: Optional[int] = None
187196

188197
def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True):
189198
_sql_template = get_sql_example_template(db_type)
@@ -213,7 +222,10 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T
213222
example_answer_3=_example_answer_3)
214223

215224
def sql_user_question(self, current_time: str, change_title: bool):
216-
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
225+
_question = self.question
226+
if self.regenerate_record_id:
227+
_question = get_sql_template()['regenerate_hint'] + self.question
228+
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=_question,
217229
rule=self.rule, current_time=current_time, error_msg=self.error_msg,
218230
change_title=change_title)
219231

backend/apps/chat/task/llm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ def is_running(self, timeout=0.5):
167167
def init_messages(self):
168168
last_sql_messages: List[dict[str, Any]] = self.generate_sql_logs[-1].messages if len(
169169
self.generate_sql_logs) > 0 else []
170+
if self.chat_question.regenerate_record_id:
171+
# filter record before regenerate_record_id
172+
_temp_log = next(
173+
filter(lambda obj: obj.pid == self.chat_question.regenerate_record_id, self.generate_sql_logs), None)
174+
last_sql_messages: List[dict[str, Any]] = _temp_log.messages if _temp_log else []
170175

171176
# todo maybe can configure
172177
count_limit = 0 - base_message_count_limit
@@ -947,6 +952,11 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
947952
# return id
948953
if in_chat:
949954
yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n'
955+
if self.get_record().regenerate_record_id:
956+
yield 'data:' + orjson.dumps({'type': 'regenerate_record_id',
957+
'regenerate_record_id': self.get_record().regenerate_record_id}).decode() + '\n\n'
958+
yield 'data:' + orjson.dumps(
959+
{'type': 'question', 'question': self.get_record().question}).decode() + '\n\n'
950960
if not stream:
951961
json_result['record_id'] = self.get_record().id
952962

0 commit comments

Comments
 (0)