Skip to content

Commit 6b99a02

Browse files
committed
feat: add chat recommend questions
1 parent c8fc476 commit 6b99a02

File tree

11 files changed

+363
-165
lines changed

11 files changed

+363
-165
lines changed

backend/apps/chat/api/chat.py

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import traceback
22

3-
import orjson
43
from fastapi import APIRouter, HTTPException
54
from fastapi.responses import StreamingResponse
65

76
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
87
delete_chat
98
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion
10-
from apps.chat.task.llm import LLMService, run_task
9+
from apps.chat.task.llm import LLMService, run_task, run_analysis_or_predict_task, run_recommend_questions_task
1110
from common.core.deps import SessionDep, CurrentUser
1211

1312
router = APIRouter(tags=["Data Q&A"], prefix="/chat")
@@ -62,6 +61,26 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
6261
)
6362

6463

64+
@router.get("/recommend_questions/{chat_record_id}")
65+
async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int):
66+
try:
67+
record = session.query(ChatRecord).get(chat_record_id)
68+
if not record:
69+
raise HTTPException(
70+
status_code=400,
71+
detail=f"Chat record with id {chat_record_id} not found"
72+
)
73+
request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '')
74+
75+
llm_service = LLMService(session, current_user, request_question)
76+
llm_service.set_record(record)
77+
78+
return run_recommend_questions_task(llm_service)
79+
except Exception as e:
80+
traceback.print_exc()
81+
return '[]'
82+
83+
6584
@router.post("/question")
6685
async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion):
6786
"""Stream SQL analysis results
@@ -88,61 +107,38 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
88107
return StreamingResponse(run_task(llm_service, session), media_type="text/event-stream")
89108

90109

91-
@router.post("/record/{chart_record_id}/{action_type}")
92-
async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chart_record_id: int, action_type: str):
110+
@router.post("/record/{chat_record_id}/{action_type}")
111+
async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str):
93112
if action_type != 'analysis' and action_type != 'predict':
94113
raise HTTPException(
95114
status_code=404,
96115
detail="Not Found"
97116
)
98117

99-
record = session.query(ChatRecord).get(chart_record_id)
118+
record = session.query(ChatRecord).get(chat_record_id)
100119
if not record:
101120
raise HTTPException(
102121
status_code=400,
103-
detail=f"Chat record with id {chart_record_id} not found"
122+
detail=f"Chat record with id {chat_record_id} not found"
104123
)
105124

106125
if not record.chart:
107126
raise HTTPException(
108127
status_code=500,
109-
detail=f"Chat record with id {chart_record_id} has not generated chart, do not support to analyze it"
128+
detail=f"Chat record with id {chat_record_id} has not generated chart, do not support to analyze it"
110129
)
111130

112131
request_question = ChatQuestion(chat_id=record.chat_id, question='')
113132

114-
llm_service = LLMService(session, current_user, request_question)
115-
llm_service.set_record(record)
116-
117-
def run_task():
118-
try:
119-
if action_type == 'analysis':
120-
# generate analysis
121-
analysis_res = llm_service.generate_analysis(session=session)
122-
for chunk in analysis_res:
123-
yield orjson.dumps({'content': chunk, 'type': 'analysis-result'}).decode() + '\n\n'
124-
yield orjson.dumps({'type': 'info', 'msg': 'analysis generated'}).decode() + '\n\n'
125-
126-
yield orjson.dumps({'type': 'analysis_finish'}).decode() + '\n\n'
127-
128-
elif action_type == 'predict':
129-
# generate predict
130-
analysis_res = llm_service.generate_predict()
131-
full_text = ''
132-
for chunk in analysis_res:
133-
yield orjson.dumps({'content': chunk, 'type': 'predict-result'}).decode() + '\n\n'
134-
full_text += chunk
135-
yield orjson.dumps({'type': 'info', 'msg': 'predict generated'}).decode() + '\n\n'
136-
137-
_data = llm_service.check_save_predict_data(res=full_text)
138-
yield orjson.dumps({'type': 'predict', 'content': _data}).decode() + '\n\n'
139-
140-
yield orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n'
141-
142-
143-
except Exception as e:
144-
traceback.print_exc()
145-
# llm_service.save_error(session=session, message=str(e))
146-
yield orjson.dumps({'content': str(e), 'type': 'error'}).decode() + '\n\n'
133+
try:
134+
llm_service = LLMService(session, current_user, request_question)
135+
llm_service.set_record(record)
136+
except Exception as e:
137+
traceback.print_exc()
138+
raise HTTPException(
139+
status_code=500,
140+
detail=str(e)
141+
)
147142

148-
return StreamingResponse(run_task(), media_type="text/event-stream")
143+
return StreamingResponse(run_analysis_or_predict_task(llm_service, action_type),
144+
media_type="text/event-stream")

backend/apps/chat/curd/chat.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import datetime
22
from typing import List
33

4-
from sqlalchemy import and_
4+
from sqlalchemy import and_, distinct
55
from sqlalchemy.orm import load_only
66

77
from apps.chat.models.chat_model import Chat, ChatRecord, CreateChat, ChatInfo, RenameChat, ChatQuestion
88
from apps.datasource.models.datasource import CoreDatasource
99
from common.core.deps import SessionDep, CurrentUser
10+
from common.utils.utils import extract_nested_json
1011

1112

1213
def list_chats(session: SessionDep, current_user: CurrentUser) -> List[Chat]:
@@ -130,8 +131,6 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj:
130131
_record.id = record.id
131132
session.commit()
132133

133-
# todo suggest questions
134-
135134
chat_info.records.append(_record)
136135

137136
return chat_info
@@ -252,6 +251,36 @@ def save_full_select_datasource_message_and_answer(session: SessionDep, record_i
252251
return result
253252

254253

254+
def save_full_recommend_question_message_and_answer(session: SessionDep, record_id: int, answer: str,
255+
full_message: str) -> ChatRecord:
256+
if not record_id:
257+
raise Exception("Record id cannot be None")
258+
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
259+
record.full_recommended_question_message = full_message
260+
record.recommended_question_answer = answer
261+
262+
json_str = '[]'
263+
if answer and answer != '':
264+
try:
265+
json_str = extract_nested_json(answer)
266+
267+
if not json_str:
268+
json_str = '[]'
269+
except Exception as e:
270+
pass
271+
record.recommended_question = json_str
272+
273+
result = ChatRecord(**record.model_dump())
274+
275+
session.add(record)
276+
session.flush()
277+
session.refresh(record)
278+
279+
session.commit()
280+
281+
return result
282+
283+
255284
def save_sql(session: SessionDep, record_id: int, sql: str) -> ChatRecord:
256285
if not record_id:
257286
raise Exception("Record id cannot be None")
@@ -379,3 +408,12 @@ def finish_record(session: SessionDep, record_id: int) -> ChatRecord:
379408
session.commit()
380409

381410
return result
411+
412+
413+
def get_old_questions(session: SessionDep, datasource: int):
414+
if not datasource:
415+
return []
416+
records = session.query(ChatRecord.question, ChatRecord.create_time).filter(ChatRecord.datasource == datasource,
417+
ChatRecord.question != None).order_by(
418+
ChatRecord.create_time.desc()).limit(20).all()
419+
return records

backend/apps/chat/models/chat_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ def datasource_sys_question(self):
130130
def datasource_user_question(self, datasource_list: str = "[]"):
131131
return get_datasource_template()['user'].format(question=self.question, data=datasource_list, lang=self.lang)
132132

133-
def datasource_guess_sys_question(self):
133+
def guess_sys_question(self):
134134
return get_guess_question_template()['system']
135135

136-
def datasource_guess_user_question(self, old_questions: str = "[]"):
136+
def guess_user_question(self, old_questions: str = "[]"):
137137
return get_guess_question_template()['user'].format(question=self.question, schema=self.db_schema,
138138
old_questions=old_questions, lang=self.lang)
139139

backend/apps/chat/task/llm.py

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
from apps.chat.curd.chat import save_question, save_full_sql_message, save_full_sql_message_and_answer, save_sql, \
1818
save_error_message, save_sql_exec_data, save_full_chart_message, save_full_chart_message_and_answer, save_chart, \
1919
finish_record, save_full_analysis_message_and_answer, save_full_predict_message_and_answer, save_predict_data, \
20-
save_full_select_datasource_message_and_answer, list_records
20+
save_full_select_datasource_message_and_answer, list_records, save_full_recommend_question_message_and_answer, \
21+
get_old_questions
2122
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat
2223
from apps.datasource.crud.datasource import get_table_schema
2324
from apps.datasource.models.datasource import CoreDatasource
2425
from apps.db.db import exec_sql
2526
from common.core.config import settings
2627
from common.core.deps import SessionDep, CurrentUser
28+
from common.utils.utils import extract_nested_json
2729

2830
warnings.filterwarnings("ignore")
2931

@@ -59,7 +61,6 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question
5961

6062
chat_question.engine = ds.type_name if ds.type != 'excel' else 'PostgreSQL'
6163

62-
6364
history_records: List[ChatRecord] = list(
6465
map(lambda x: ChatRecord(**x.model_dump()), filter(lambda r: True if r.first_chat != True else False,
6566
list_records(session=self.session,
@@ -75,7 +76,6 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question
7576
self.chat_question = chat_question
7677
self.config = get_default_config()
7778
self.chat_question.ai_modal_id = self.config.model_id
78-
7979

8080
# Create LLM instance through factory
8181
llm_instance = LLMFactory.create_llm(self.config)
@@ -176,7 +176,7 @@ def get_fields_from_chart(self):
176176
fields.append(column_str)
177177
return fields
178178

179-
def generate_analysis(self, session: SessionDep):
179+
def generate_analysis(self):
180180
fields = self.get_fields_from_chart()
181181

182182
self.chat_question.fields = orjson.dumps(fields).decode()
@@ -189,7 +189,7 @@ def generate_analysis(self, session: SessionDep):
189189
if self.record.full_analysis_message and self.record.full_analysis_message.strip() != '':
190190
history_msg = orjson.loads(self.record.full_analysis_message)
191191

192-
self.record = save_full_analysis_message_and_answer(session=session, record_id=self.record.id, answer='',
192+
self.record = save_full_analysis_message_and_answer(session=self.session, record_id=self.record.id, answer='',
193193
full_message=orjson.dumps(history_msg +
194194
[{'type': msg.type,
195195
'content': msg.content} for msg
@@ -210,7 +210,7 @@ def generate_analysis(self, session: SessionDep):
210210
continue
211211

212212
analysis_msg.append(AIMessage(full_analysis_text))
213-
self.record = save_full_analysis_message_and_answer(session=session, record_id=self.record.id,
213+
self.record = save_full_analysis_message_and_answer(session=self.session, record_id=self.record.id,
214214
answer=full_analysis_text,
215215
full_message=orjson.dumps(history_msg +
216216
[{'type': msg.type,
@@ -261,6 +261,47 @@ def generate_predict(self):
261261
in
262262
predict_msg]).decode())
263263

264+
def generate_recommend_questions_task(self):
265+
266+
# get schema
267+
if self.ds and not self.chat_question.db_schema:
268+
self.chat_question.db_schema = get_table_schema(session=self.session, ds=self.ds)
269+
270+
guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
271+
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question()))
272+
# todo old questions
273+
old_questions = list(map(lambda q: q[0].strip(), get_old_questions(self.session, self.record.datasource)))
274+
guess_msg.append(HumanMessage(content=self.chat_question.guess_user_question(orjson.dumps(old_questions).decode())))
275+
276+
self.record = save_full_recommend_question_message_and_answer(session=self.session, record_id=self.record.id,
277+
answer='',
278+
full_message=orjson.dumps([{'type': msg.type,
279+
'content': msg.content}
280+
for msg
281+
in
282+
guess_msg]).decode())
283+
284+
full_guess_text = ''
285+
res = self.llm.stream(guess_msg)
286+
for chunk in res:
287+
print(chunk)
288+
if isinstance(chunk, dict):
289+
full_guess_text += chunk['content']
290+
continue
291+
if isinstance(chunk, AIMessageChunk):
292+
full_guess_text += chunk.content
293+
continue
294+
295+
guess_msg.append(AIMessage(full_guess_text))
296+
self.record = save_full_recommend_question_message_and_answer(session=self.session, record_id=self.record.id,
297+
answer=full_guess_text,
298+
full_message=orjson.dumps([{'type': msg.type,
299+
'content': msg.content}
300+
for msg
301+
in
302+
guess_msg]).decode())
303+
return self.record.recommended_question
304+
264305
def select_datasource(self):
265306
datasource_msg: List[Union[BaseMessage, dict[str, Any]]] = []
266307
datasource_msg.append(SystemMessage(self.chat_question.datasource_sys_question()))
@@ -486,33 +527,6 @@ def execute_sql(self, sql: str):
486527
return exec_sql(self.ds, sql)
487528

488529

489-
def extract_nested_json(text):
490-
stack = []
491-
start_index = -1
492-
results = []
493-
494-
for i, char in enumerate(text):
495-
if char in '{[':
496-
if not stack: # 记录起始位置
497-
start_index = i
498-
stack.append(char)
499-
elif char in '}]':
500-
if stack and ((char == '}' and stack[-1] == '{') or (char == ']' and stack[-1] == '[')):
501-
stack.pop()
502-
if not stack: # 栈空时截取完整JSON
503-
json_str = text[start_index:i + 1]
504-
try:
505-
orjson.loads(json_str) # 验证有效性
506-
results.append(json_str)
507-
except:
508-
pass
509-
else:
510-
stack = [] # 括号不匹配则重置
511-
if len(results) > 0 and results[0]:
512-
return results[0]
513-
return None
514-
515-
516530
def execute_sql_with_db(db: SQLDatabase, sql: str) -> str:
517531
"""Execute SQL query using SQLDatabase
518532
@@ -647,6 +661,42 @@ def run_task(llm_service: LLMService, session: SessionDep, in_chat: bool = True)
647661
yield f'> ❌ **ERROR**\n\n> \n\n> {str(e)}。'
648662

649663

664+
def run_analysis_or_predict_task(llm_service: LLMService, action_type: str):
665+
try:
666+
if action_type == 'analysis':
667+
# generate analysis
668+
analysis_res = llm_service.generate_analysis()
669+
for chunk in analysis_res:
670+
yield orjson.dumps({'content': chunk, 'type': 'analysis-result'}).decode() + '\n\n'
671+
yield orjson.dumps({'type': 'info', 'msg': 'analysis generated'}).decode() + '\n\n'
672+
673+
yield orjson.dumps({'type': 'analysis_finish'}).decode() + '\n\n'
674+
675+
elif action_type == 'predict':
676+
# generate predict
677+
analysis_res = llm_service.generate_predict()
678+
full_text = ''
679+
for chunk in analysis_res:
680+
yield orjson.dumps({'content': chunk, 'type': 'predict-result'}).decode() + '\n\n'
681+
full_text += chunk
682+
yield orjson.dumps({'type': 'info', 'msg': 'predict generated'}).decode() + '\n\n'
683+
684+
_data = llm_service.check_save_predict_data(res=full_text)
685+
yield orjson.dumps({'type': 'predict', 'content': _data}).decode() + '\n\n'
686+
687+
yield orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n'
688+
689+
690+
except Exception as e:
691+
traceback.print_exc()
692+
# llm_service.save_error(session=session, message=str(e))
693+
yield orjson.dumps({'content': str(e), 'type': 'error'}).decode() + '\n\n'
694+
695+
696+
def run_recommend_questions_task(llm_service: LLMService):
697+
return llm_service.generate_recommend_questions_task()
698+
699+
650700
def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
651701
file_name = f'c_{chat_id}_r_{record_id}'
652702

0 commit comments

Comments (0)