Skip to content

Commit 6d39a77

Browse files
committed
feat: split analysis & predict in chat
1 parent f556e1d commit 6d39a77

File tree

8 files changed

+164
-65
lines changed

8 files changed

+164
-65
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""024_modify_chat_record
2+
3+
Revision ID: 806bc67ff45f
4+
Revises: f535d09946f6
5+
Create Date: 2025-07-11 18:09:52.417628
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '806bc67ff45f'
15+
down_revision = 'f535d09946f6'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
    """Apply 024_modify_chat_record: replace chat_record.run_time with
    per-action link columns.

    Adds nullable analysis_record_id / predict_record_id (BigInteger) used to
    link an analysis/predict child record back to its base chat record, and
    drops the now-unused run_time column.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('chat_record', sa.Column('analysis_record_id', sa.BigInteger(), nullable=True))
    op.add_column('chat_record', sa.Column('predict_record_id', sa.BigInteger(), nullable=True))
    op.drop_column('chat_record', 'run_time')
    # ### end Alembic commands ###
26+
27+
28+
def downgrade():
    """Revert 024_modify_chat_record: restore run_time and drop the
    analysis/predict link columns.

    run_time is NOT NULL, so a server_default is required to backfill
    existing rows when the column is re-added — without it the ALTER TABLE
    fails on any non-empty chat_record table. The default is cleared
    afterwards to restore the original schema (the model only had a
    Python-side default of 0, not a server default).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('chat_record',
                  sa.Column('run_time', sa.DOUBLE_PRECISION(precision=53),
                            server_default=sa.text('0'), autoincrement=False, nullable=False))
    op.alter_column('chat_record', 'run_time', server_default=None)
    op.drop_column('chat_record', 'predict_record_id')
    op.drop_column('chat_record', 'analysis_record_id')
    # ### end Alembic commands ###

backend/apps/chat/api/chat.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
129129
detail=str(e)
130130
)
131131

132-
return StreamingResponse(run_task(llm_service, session), media_type="text/event-stream")
132+
return StreamingResponse(run_task(llm_service), media_type="text/event-stream")
133133

134134

135135
@router.post("/record/{chat_record_id}/{action_type}")
@@ -157,13 +157,12 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
157157

158158
try:
159159
llm_service = LLMService(session, current_user, request_question)
160-
llm_service.set_record(record)
161160
except Exception as e:
162161
traceback.print_exc()
163162
raise HTTPException(
164163
status_code=500,
165164
detail=str(e)
166165
)
167166

168-
return StreamingResponse(run_analysis_or_predict_task(llm_service, action_type),
167+
return StreamingResponse(run_analysis_or_predict_task(llm_service, action_type, record),
169168
media_type="text/event-stream")

backend/apps/chat/curd/chat.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
8282
load_only(ChatRecord.id, ChatRecord.chat_id, ChatRecord.create_time, ChatRecord.finish_time,
8383
ChatRecord.question, ChatRecord.sql_answer, ChatRecord.sql, ChatRecord.data,
8484
ChatRecord.chart_answer, ChatRecord.chart, ChatRecord.analysis, ChatRecord.predict,
85-
ChatRecord.datasource_select_answer,
85+
ChatRecord.datasource_select_answer, ChatRecord.analysis_record_id, ChatRecord.predict_record_id,
8686
ChatRecord.recommended_question, ChatRecord.first_chat,
87-
ChatRecord.predict_data, ChatRecord.finish, ChatRecord.error, ChatRecord.run_time)).filter(
87+
ChatRecord.predict_data, ChatRecord.finish, ChatRecord.error)).filter(
8888
and_(Chat.create_by == current_user.id, ChatRecord.chat_id == chart_id)).order_by(ChatRecord.create_time).all()
8989

9090
result = list(map(format_record, record_list))
@@ -130,9 +130,11 @@ def format_record(record: ChatRecord):
130130
return _dict
131131

132132

133-
def list_records(session: SessionDep, chart_id: int, current_user: CurrentUser) -> List[ChatRecord]:
133+
def list_base_records(session: SessionDep, chart_id: int, current_user: CurrentUser) -> List[ChatRecord]:
    """Return the current user's base chat records for a chat, ordered by
    creation time, excluding analysis/predict child records.

    A record is a "base" record when both analysis_record_id and
    predict_record_id are NULL.
    """
    record_list = session.query(ChatRecord).filter(
        and_(Chat.create_by == current_user.id, ChatRecord.chat_id == chart_id,
             # Bug fix: `ChatRecord.analysis_record_id is None` is a Python
             # identity test that always evaluates to a plain False bool,
             # not a SQL "IS NULL" expression — SQLAlchemy coerces it to a
             # false clause, so this query silently matched no rows.
             # Use .is_(None) to emit a real IS NULL comparison.
             ChatRecord.analysis_record_id.is_(None),
             ChatRecord.predict_record_id.is_(None))).order_by(
        ChatRecord.create_time).all()
    return record_list
137139

138140

@@ -225,6 +227,34 @@ def save_question(session: SessionDep, current_user: CurrentUser, question: Chat
225227
return result
226228

227229

230+
def save_analysis_predict_record(session: SessionDep, base_record: ChatRecord, action_type: str) -> ChatRecord:
    """Persist a new ChatRecord that will hold the analysis or predict output
    derived from *base_record*, linked back to it via analysis_record_id /
    predict_record_id depending on *action_type*.

    Returns a detached copy of the saved record carrying the generated id, so
    the caller can keep using it after the session transaction is committed.
    """
    record = ChatRecord()
    # Copy the question/context fields the LLM task needs from the base record.
    record.question = base_record.question
    record.chat_id = base_record.chat_id
    record.datasource = base_record.datasource
    record.engine_type = base_record.engine_type
    record.ai_modal_id = base_record.ai_modal_id
    record.create_time = datetime.datetime.now()
    # NOTE(review): this stores the base record's primary key in create_by,
    # but elsewhere create_by holds a user id (e.g. Chat.create_by ==
    # current_user.id). Looks like it should be base_record.create_by —
    # confirm against the ChatRecord model and its other writers.
    record.create_by = base_record.id
    record.chart = base_record.chart
    record.data = base_record.data

    # Link the child record back to its parent according to the action type;
    # any other action_type leaves both link columns unset.
    if action_type == 'analysis':
        record.analysis_record_id = base_record.id
    elif action_type == 'predict':
        record.predict_record_id = base_record.id

    # Detached copy to return; `record` itself stays bound to the session.
    result = ChatRecord(**record.model_dump())

    session.add(record)
    session.flush()
    session.refresh(record)  # populate the autogenerated primary key
    result.id = record.id
    session.commit()

    return result
256+
257+
228258
def save_full_sql_message(session: SessionDep, record_id: int, full_message: str) -> ChatRecord:
229259
return save_full_sql_message_and_answer(session=session, record_id=record_id, full_message=full_message, answer='')
230260

backend/apps/chat/models/chat_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ class ChatRecord(SQLModel, table=True):
6262
token_select_datasource_question: str = Field(max_length=256, nullable=True)
6363
finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
6464
error: str = Field(sa_column=Column(Text, nullable=True))
65-
run_time: float = Field(default=0)
66-
65+
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
66+
predict_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
6767

6868
class CreateChat(BaseModel):
6969
id: int = None

backend/apps/chat/task/llm.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from apps.chat.curd.chat import save_question, save_full_sql_message, save_full_sql_message_and_answer, save_sql, \
1818
save_error_message, save_sql_exec_data, save_full_chart_message, save_full_chart_message_and_answer, save_chart, \
1919
finish_record, save_full_analysis_message_and_answer, save_full_predict_message_and_answer, save_predict_data, \
20-
save_full_select_datasource_message_and_answer, list_records, save_full_recommend_question_message_and_answer, \
21-
get_old_questions
20+
save_full_select_datasource_message_and_answer, save_full_recommend_question_message_and_answer, \
21+
get_old_questions, save_analysis_predict_record, list_base_records
2222
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat
2323
from apps.datasource.crud.datasource import get_table_schema
2424
from apps.datasource.models.datasource import CoreDatasource
@@ -63,9 +63,9 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question
6363

6464
history_records: List[ChatRecord] = list(
6565
map(lambda x: ChatRecord(**x.model_dump()), filter(lambda r: True if r.first_chat != True else False,
66-
list_records(session=self.session,
67-
current_user=current_user,
68-
chart_id=chat_question.chat_id))))
66+
list_base_records(session=self.session,
67+
current_user=current_user,
68+
chart_id=chat_question.chat_id))))
6969
# get schema
7070
if ds:
7171
chat_question.db_schema = get_table_schema(session=self.session, ds=ds)
@@ -606,7 +606,7 @@ def execute_sql_with_db(db: SQLDatabase, sql: str) -> str:
606606
raise RuntimeError(error_msg)
607607

608608

609-
def run_task(llm_service: LLMService, session: SessionDep, in_chat: bool = True):
609+
def run_task(llm_service: LLMService, in_chat: bool = True):
610610
try:
611611
# return id
612612
if in_chat:
@@ -626,7 +626,7 @@ def run_task(llm_service: LLMService, session: SessionDep, in_chat: bool = True)
626626
yield orjson.dumps({'id': llm_service.ds.id, 'datasource_name': llm_service.ds.name,
627627
'engine_type': llm_service.ds.type_name, 'type': 'datasource'}).decode() + '\n\n'
628628

629-
llm_service.chat_question.db_schema = get_table_schema(session=session, ds=llm_service.ds)
629+
llm_service.chat_question.db_schema = get_table_schema(session=llm_service.session, ds=llm_service.ds)
630630

631631
# generate sql
632632
sql_res = llm_service.generate_sql()
@@ -720,8 +720,10 @@ def run_task(llm_service: LLMService, session: SessionDep, in_chat: bool = True)
720720
yield f'> ❌ **ERROR**\n\n> \n\n> {str(e)}。'
721721

722722

723-
def run_analysis_or_predict_task(llm_service: LLMService, action_type: str):
723+
def run_analysis_or_predict_task(llm_service: LLMService, action_type: str, base_record: ChatRecord):
724724
try:
725+
llm_service.set_record(save_analysis_predict_record(llm_service.session, base_record, action_type))
726+
725727
if action_type == 'analysis':
726728
# generate analysis
727729
analysis_res = llm_service.generate_analysis()
@@ -752,10 +754,10 @@ def run_analysis_or_predict_task(llm_service: LLMService, action_type: str):
752754

753755
yield orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n'
754756

755-
757+
llm_service.finish()
756758
except Exception as e:
757759
traceback.print_exc()
758-
# llm_service.save_error(session=session, message=str(e))
760+
llm_service.save_error(message=str(e))
759761
yield orjson.dumps({'content': str(e), 'type': 'error'}).decode() + '\n\n'
760762
finally:
761763
# end

backend/apps/mcp/mcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,4 @@ async def mcp_question(session: SessionDep, chat: ChatMcp):
6666
llm_service = LLMService(session, user, chat)
6767
llm_service.init_record()
6868

69-
return StreamingResponse(run_task(llm_service, session, False), media_type="text/event-stream")
69+
return StreamingResponse(run_task(llm_service, False), media_type="text/event-stream")

frontend/src/api/chat.ts

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ export class ChatRecord {
4747
run_time: number = 0
4848
first_chat: boolean = false
4949
recommended_question?: string
50+
analysis_record_id?: number
51+
predict_record_id?: number
5052

5153
constructor()
5254
constructor(
@@ -69,7 +71,9 @@ export class ChatRecord {
6971
error: string | undefined,
7072
run_time: number,
7173
first_chat: boolean,
72-
recommended_question: string | undefined
74+
recommended_question: string | undefined,
75+
analysis_record_id: number | undefined,
76+
predict_record_id: number | undefined
7377
)
7478
constructor(
7579
id?: number,
@@ -91,7 +95,9 @@ export class ChatRecord {
9195
error?: string,
9296
run_time?: number,
9397
first_chat?: boolean,
94-
recommended_question?: string
98+
recommended_question?: string,
99+
analysis_record_id?: number,
100+
predict_record_id?: number
95101
) {
96102
this.id = id
97103
this.chat_id = chat_id
@@ -113,6 +119,8 @@ export class ChatRecord {
113119
this.run_time = run_time ?? 0
114120
this.first_chat = !!first_chat
115121
this.recommended_question = recommended_question
122+
this.analysis_record_id = analysis_record_id
123+
this.predict_record_id = predict_record_id
116124
}
117125
}
118126

@@ -235,7 +243,9 @@ const toChatRecord = (data?: any): ChatRecord | undefined => {
235243
data.error,
236244
data.run_time,
237245
data.first_chat,
238-
data.recommended_question
246+
data.recommended_question,
247+
data.analysis_record_id,
248+
data.predict_record_id
239249
)
240250
}
241251
const toChatRecordList = (list: any = []): ChatRecord[] => {

0 commit comments

Comments
 (0)