Skip to content

Commit 1849f9f

Browse files
WainWongclaude
andcommitted
feat(embedding): support multi-turn conversation context for table RAG
- Add MULTI_TURN_EMBEDDING_ENABLED and MULTI_TURN_HISTORY_COUNT config - Add get_chat_history_questions() to retrieve recent questions from same chat - Add build_context_query() to concatenate history questions with current question - Update calc_table_embedding() to use context query for better table matching - Pass history_questions through get_table_schema() to LLMService This improves table structure retrieval accuracy by considering the full conversation context instead of just the latest question. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent a6dcec0 commit 1849f9f

File tree

5 files changed

+100
-8
lines changed

5 files changed

+100
-8
lines changed

backend/apps/chat/curd/chat.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,3 +839,36 @@ def get_old_questions(session: SessionDep, datasource: int):
839839
for r in result:
840840
records.append(r.question)
841841
return records
842+
843+
844+
def get_chat_history_questions(session: SessionDep, chat_id: int, limit: int = 3) -> List[str]:
845+
"""
846+
获取当前chat的历史问题列表(按时间正序,最旧的在前)
847+
848+
Args:
849+
session: 数据库会话
850+
chat_id: 当前对话ID
851+
limit: 获取的历史问题数量
852+
853+
Returns:
854+
历史问题列表,按时间正序排列
855+
"""
856+
stmt = (
857+
select(ChatRecord.question)
858+
.where(
859+
and_(
860+
ChatRecord.chat_id == chat_id,
861+
ChatRecord.question.isnot(None),
862+
ChatRecord.question != '',
863+
ChatRecord.error.is_(None)
864+
)
865+
)
866+
.order_by(ChatRecord.create_time.desc())
867+
.limit(limit)
868+
)
869+
870+
result = session.execute(stmt)
871+
questions = [row.question for row in result if row.question and row.question.strip()]
872+
873+
# 反转列表,使最旧的在前
874+
return list(reversed(questions))

backend/apps/chat/task/llm.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \
3232
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
3333
get_last_execute_sql_error, format_json_data, format_chart_fields, get_chat_brief_generate, get_chat_predict_data, \
34-
get_chat_chart_config
34+
get_chat_chart_config, get_chat_history_questions
3535
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
3636
ChatFinishStep, AxisObj
3737
from apps.data_training.curd.data_training import get_training_template
@@ -101,6 +101,12 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
101101
chat: Chat | None = session.get(Chat, chat_id)
102102
if not chat:
103103
raise SingleMessageError(f"Chat with id {chat_id} not found")
104+
105+
# 获取历史问题(用于多轮对话embedding)
106+
history_questions = []
107+
if settings.MULTI_TURN_EMBEDDING_ENABLED:
108+
history_questions = get_chat_history_questions(session, chat_id, settings.MULTI_TURN_HISTORY_COUNT)
109+
104110
ds: CoreDatasource | AssistantOutDsSchema | None = None
105111
if chat.datasource:
106112
# Get available datasource
@@ -117,7 +123,8 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
117123
raise SingleMessageError("No available datasource configuration found")
118124
chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds)
119125
chat_question.db_schema = get_table_schema(session=session, current_user=current_user, ds=ds,
120-
question=chat_question.question, embedding=embedding)
126+
question=chat_question.question, embedding=embedding,
127+
history_questions=history_questions)
121128

122129
self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id)
123130
self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id)

backend/apps/datasource/crud/datasource.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
416416

417417

418418
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
419-
embedding: bool = True) -> str:
419+
embedding: bool = True, history_questions: List[str] = None) -> str:
420420
schema_str = ""
421421
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
422422
if len(table_objs) == 0:
@@ -455,7 +455,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
455455

456456
# do table embedding
457457
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
458-
tables = calc_table_embedding(tables, question)
458+
tables = calc_table_embedding(tables, question, history_questions)
459459
# splice schema
460460
if tables:
461461
for s in tables:

backend/apps/datasource/embedding/table_embedding.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,46 @@
33
import json
44
import time
55
import traceback
6+
from typing import List
67

78
from apps.ai_model.embedding import EmbeddingModelCache
89
from apps.datasource.embedding.utils import cosine_similarity
910
from common.core.config import settings
1011
from common.utils.utils import SQLBotLogUtil
1112

1213

13-
def get_table_embedding(tables: list[dict], question: str):
14+
def build_context_query(current_question: str, history_questions: List[str] = None) -> str:
15+
"""
16+
构建包含上下文的查询文本
17+
18+
Args:
19+
current_question: 当前问题
20+
history_questions: 历史问题列表(按时间正序,最旧的在前)
21+
22+
Returns:
23+
拼接后的查询文本
24+
"""
25+
if not settings.MULTI_TURN_EMBEDDING_ENABLED or not history_questions:
26+
return current_question
27+
28+
max_history = settings.MULTI_TURN_HISTORY_COUNT
29+
recent_history = history_questions[-max_history:] if history_questions else []
30+
31+
if not recent_history:
32+
return current_question
33+
34+
# 拼接:历史问题 + 当前问题
35+
context_parts = recent_history + [current_question]
36+
37+
# 使用分隔符拼接,保持语义连贯
38+
context_query = " | ".join(context_parts)
39+
40+
SQLBotLogUtil.info(f"Context query for embedding: {context_query}")
41+
42+
return context_query
43+
44+
45+
def get_table_embedding(tables: list[dict], question: str, history_questions: List[str] = None):
1446
_list = []
1547
for table in tables:
1648
_list.append({"id": table.get('id'), "schema_table": table.get('schema_table'), "cosine_similarity": 0.0})
@@ -25,7 +57,9 @@ def get_table_embedding(tables: list[dict], question: str):
2557
end_time = time.time()
2658
SQLBotLogUtil.info(str(end_time - start_time))
2759

28-
q_embedding = model.embed_query(question)
60+
# 构建包含上下文的查询
61+
context_query = build_context_query(question, history_questions)
62+
q_embedding = model.embed_query(context_query)
2963
for index in range(len(results)):
3064
item = results[index]
3165
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
@@ -40,7 +74,18 @@ def get_table_embedding(tables: list[dict], question: str):
4074
return _list
4175

4276

43-
def calc_table_embedding(tables: list[dict], question: str):
77+
def calc_table_embedding(tables: list[dict], question: str, history_questions: List[str] = None):
78+
"""
79+
计算表结构与问题的embedding相似度
80+
81+
Args:
82+
tables: 表结构列表
83+
question: 当前问题
84+
history_questions: 历史问题列表(可选,用于多轮对话)
85+
86+
Returns:
87+
按相似度排序的表列表
88+
"""
4489
_list = []
4590
for table in tables:
4691
_list.append(
@@ -58,7 +103,9 @@ def calc_table_embedding(tables: list[dict], question: str):
58103
# SQLBotLogUtil.info(str(end_time - start_time))
59104
results = [item.get('embedding') for item in _list]
60105

61-
q_embedding = model.embed_query(question)
106+
# 构建包含上下文的查询
107+
context_query = build_context_query(question, history_questions)
108+
q_embedding = model.embed_query(context_query)
62109
for index in range(len(results)):
63110
item = results[index]
64111
if item:

backend/common/core/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
115115
TABLE_EMBEDDING_COUNT: int = 10
116116
DS_EMBEDDING_COUNT: int = 10
117117

118+
# Multi-turn embedding settings
119+
MULTI_TURN_EMBEDDING_ENABLED: bool = True
120+
MULTI_TURN_HISTORY_COUNT: int = 3
121+
118122
ORACLE_CLIENT_PATH: str = '/opt/sqlbot/db_client/oracle_instant_client'
119123

120124
@field_validator('SQL_DEBUG',
@@ -123,6 +127,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
123127
'PARSE_REASONING_BLOCK_ENABLED',
124128
'PG_POOL_PRE_PING',
125129
'TABLE_EMBEDDING_ENABLED',
130+
'MULTI_TURN_EMBEDDING_ENABLED',
126131
mode='before')
127132
@classmethod
128133
def lowercase_bool(cls, v: Any) -> Any:

0 commit comments

Comments
 (0)