Skip to content

Commit fd9a59a

Browse files
WainWongclaude
andcommitted
feat(table-selection): add LLM-based table selection for SQL generation
Add a new LLM-based table selection feature that can replace or work alongside the existing RAG-based table embedding. Key changes: - New table selection module (backend/apps/datasource/llm_select/) - New config option TABLE_LLM_SELECTION_ENABLED (default: true) - Add table_select_answer field to ChatRecord for logging LLM selections - Add SELECT_TABLE operation type for tracking in chat logs - Skip foreign key relation table completion when using LLM selection (LLM already sees table relations and can decide which tables to include) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent a6dcec0 commit fd9a59a

File tree

11 files changed

+456
-12
lines changed

11 files changed

+456
-12
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""add table_select_answer column to chat_record
2+
3+
Revision ID: 054_table_select
4+
Revises: 5755c0b95839
5+
Create Date: 2025-12-23
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
11+
revision = '054_table_select'
12+
down_revision = '5755c0b95839'
13+
branch_labels = None
14+
depends_on = None
15+
16+
17+
def upgrade():
18+
op.add_column('chat_record', sa.Column('table_select_answer', sa.Text(), nullable=True))
19+
20+
21+
def downgrade():
22+
op.drop_column('chat_record', 'table_select_answer')

backend/apps/chat/curd/chat.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,26 @@ def save_select_datasource_answer(session: SessionDep, record_id: int, answer: s
649649
return result
650650

651651

652+
def save_table_select_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord:
653+
"""保存 LLM 表选择的结果到 ChatRecord"""
654+
if not record_id:
655+
raise Exception("Record id cannot be None")
656+
record = get_chat_record_by_id(session, record_id)
657+
658+
record.table_select_answer = answer
659+
660+
result = ChatRecord(**record.model_dump())
661+
662+
stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
663+
table_select_answer=record.table_select_answer,
664+
)
665+
666+
session.execute(stmt)
667+
session.commit()
668+
669+
return result
670+
671+
652672
def save_recommend_question_answer(session: SessionDep, record_id: int,
653673
answer: dict = None) -> ChatRecord:
654674
if not record_id:

backend/apps/chat/models/chat_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class OperationEnum(Enum):
4040
GENERATE_SQL_WITH_PERMISSIONS = '5'
4141
CHOOSE_DATASOURCE = '6'
4242
GENERATE_DYNAMIC_SQL = '7'
43+
SELECT_TABLE = '8' # LLM 表选择
4344

4445

4546
class ChatFinishStep(Enum):
@@ -112,6 +113,7 @@ class ChatRecord(SQLModel, table=True):
112113
recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True))
113114
recommended_question: str = Field(sa_column=Column(Text, nullable=True))
114115
datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True))
116+
table_select_answer: str = Field(sa_column=Column(Text, nullable=True))
115117
finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
116118
error: str = Field(sa_column=Column(Text, nullable=True))
117119
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
@@ -137,6 +139,7 @@ class ChatRecordResult(BaseModel):
137139
predict_data: Optional[str] = None
138140
recommended_question: Optional[str] = None
139141
datasource_select_answer: Optional[str] = None
142+
table_select_answer: Optional[str] = None
140143
finish: Optional[bool] = None
141144
error: Optional[str] = None
142145
analysis_record_id: Optional[int] = None

backend/apps/chat/task/llm.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,16 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
116116
if not ds:
117117
raise SingleMessageError("No available datasource configuration found")
118118
chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds)
119-
chat_question.db_schema = get_table_schema(session=session, current_user=current_user, ds=ds,
120-
question=chat_question.question, embedding=embedding)
119+
# 延迟 get_table_schema 调用到 init_record 之后,以便记录 LLM 表选择日志
120+
self._pending_schema_params = {
121+
'session': session,
122+
'current_user': current_user,
123+
'ds': ds,
124+
'question': chat_question.question,
125+
'embedding': embedding,
126+
'history_questions': history_questions,
127+
'config': config
128+
}
121129

122130
self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id)
123131
self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id)
@@ -224,6 +232,22 @@ def init_messages(self):
224232

225233
def init_record(self, session: Session) -> ChatRecord:
226234
self.record = save_question(session=session, current_user=self.current_user, question=self.chat_question)
235+
236+
# 如果有延迟的 schema 获取,现在执行(此时 record 已存在,可以记录 LLM 表选择日志)
237+
if hasattr(self, '_pending_schema_params') and self._pending_schema_params:
238+
params = self._pending_schema_params
239+
self.chat_question.db_schema = get_table_schema(
240+
session=params['session'],
241+
current_user=params['current_user'],
242+
ds=params['ds'],
243+
question=params['question'],
244+
embedding=params['embedding'],
245+
history_questions=params['history_questions'],
246+
config=params['config'],
247+
record_id=self.record.id
248+
)
249+
self._pending_schema_params = None
250+
227251
return self.record
228252

229253
def get_record(self):
@@ -349,7 +373,9 @@ def generate_recommend_questions_task(self, _session: Session):
349373
session=_session,
350374
current_user=self.current_user, ds=self.ds,
351375
question=self.chat_question.question,
352-
embedding=False)
376+
embedding=False,
377+
config=self.config,
378+
record_id=self.record.id)
353379

354380
guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
355381
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question(self.articles_number)))
@@ -494,7 +520,9 @@ def select_datasource(self, _session: Session):
494520
self.ds)
495521
self.chat_question.db_schema = get_table_schema(session=_session,
496522
current_user=self.current_user, ds=self.ds,
497-
question=self.chat_question.question)
523+
question=self.chat_question.question,
524+
config=self.config,
525+
record_id=self.record.id)
498526
_engine_type = self.chat_question.engine
499527
_chat.engine_type = _ds.type_name
500528
# save chat
@@ -997,7 +1025,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
9971025
session=_session,
9981026
current_user=self.current_user,
9991027
ds=self.ds,
1000-
question=self.chat_question.question)
1028+
question=self.chat_question.question,
1029+
config=self.config,
1030+
record_id=self.record.id)
10011031
else:
10021032
self.validate_history_ds(_session)
10031033

backend/apps/datasource/crud/datasource.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from sqlbot_xpack.permissions.models.ds_rules import DsRules
88
from sqlmodel import select
99

10+
from apps.ai_model.model_factory import LLMConfig
1011
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
1112
from apps.datasource.embedding.table_embedding import calc_table_embedding
13+
from apps.datasource.llm_select.table_selection import calc_table_llm_selection
1214
from apps.datasource.utils.utils import aes_decrypt
1315
from apps.db.constant import DB
1416
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
@@ -416,7 +418,8 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
416418

417419

418420
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
419-
embedding: bool = True) -> str:
421+
embedding: bool = True, history_questions: List[str] = None,
422+
config: LLMConfig = None, lang: str = "中文", record_id: int = None) -> str:
420423
schema_str = ""
421424
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
422425
if len(table_objs) == 0:
@@ -425,7 +428,12 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
425428
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
426429
tables = []
427430
all_tables = [] # temp save all tables
431+
432+
# 构建 table_name -> table_obj 映射,用于 LLM 表选择
433+
table_name_to_obj = {}
428434
for obj in table_objs:
435+
table_name_to_obj[obj.table.table_name] = obj
436+
429437
schema_table = ''
430438
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
431439
table_comment = ''
@@ -453,16 +461,36 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
453461
tables.append(t_obj)
454462
all_tables.append(t_obj)
455463

456-
# do table embedding
457-
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
458-
tables = calc_table_embedding(tables, question)
464+
# do table selection
465+
used_llm_selection = False # 标记是否使用了 LLM 表选择
466+
if embedding and tables:
467+
if settings.TABLE_LLM_SELECTION_ENABLED and config:
468+
# 使用 LLM 表选择
469+
selected_table_names = calc_table_llm_selection(
470+
config=config,
471+
table_objs=table_objs,
472+
question=question,
473+
ds_table_relation=ds.table_relation,
474+
history_questions=history_questions,
475+
lang=lang,
476+
session=session,
477+
record_id=record_id
478+
)
479+
if selected_table_names:
480+
# 根据选中的表名筛选 tables
481+
selected_table_ids = [table_name_to_obj[name].table.id for name in selected_table_names if name in table_name_to_obj]
482+
tables = [t for t in tables if t.get('id') in selected_table_ids]
483+
used_llm_selection = True # LLM 成功选择了表
484+
elif settings.TABLE_EMBEDDING_ENABLED:
485+
# 使用 RAG 表选择
486+
tables = calc_table_embedding(tables, question, history_questions)
459487
# splice schema
460488
if tables:
461489
for s in tables:
462490
schema_str += s.get('schema_table')
463491

464-
# field relation
465-
if tables and ds.table_relation:
492+
# field relation - LLM 表选择模式下不补全关联表,完全信任 LLM 的选择结果
493+
if tables and ds.table_relation and not used_llm_selection:
466494
relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation))
467495
if relations:
468496
# Complete the missing table
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Author: SQLBot
2+
# Date: 2025/12/23

0 commit comments

Comments
 (0)