Skip to content

Commit 487c8e3

Browse files
committed
feat: Vector retrieval matches tables
1 parent 5f77b82 commit 487c8e3

File tree

4 files changed

+57
-11
lines changed

4 files changed

+57
-11
lines changed

backend/apps/chat/task/llm.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
114114
if not ds:
115115
raise SingleMessageError("No available datasource configuration found")
116116
chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds)
117-
chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds)
117+
chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds,
118+
question=self.chat_question.question)
118119

119120
self.generate_sql_logs = list_generate_sql_logs(session=self.session, chart_id=chat_id)
120121
self.generate_chart_logs = list_generate_chart_logs(session=self.session, chart_id=chat_id)
@@ -346,7 +347,8 @@ def generate_recommend_questions_task(self):
346347
if self.ds and not self.chat_question.db_schema:
347348
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
348349
self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session,
349-
current_user=self.current_user, ds=self.ds)
350+
current_user=self.current_user, ds=self.ds,
351+
question=self.chat_question.question)
350352

351353
guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
352354
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question()))
@@ -506,7 +508,8 @@ def select_datasource(self):
506508
self.chat_question.engine = (_ds.type_name if _ds.type != 'excel' else 'PostgreSQL') + get_version(
507509
self.ds)
508510
self.chat_question.db_schema = get_table_schema(session=self.session,
509-
current_user=self.current_user, ds=self.ds)
511+
current_user=self.current_user, ds=self.ds,
512+
question=self.chat_question.question)
510513
_engine_type = self.chat_question.engine
511514
_chat.engine_type = _ds.type_name
512515
# save chat
@@ -995,7 +998,8 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
995998
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
996999
self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session,
9971000
current_user=self.current_user,
998-
ds=self.ds)
1001+
ds=self.ds,
1002+
question=self.chat_question.question)
9991003
else:
10001004
self.validate_history_ds()
10011005

backend/apps/datasource/crud/datasource.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .table import get_tables_by_ds_id
1717
from ..crud.field import delete_field_by_ds_id, update_field
1818
from ..crud.table import delete_table_by_ds_id, update_table
19+
from ..embedding.ds_embedding import get_table_embedding
1920
from ..models.datasource import CoreDatasource, CreateDatasource, CoreTable, CoreField, ColumnSchema, TableObj, \
2021
DatasourceConf, TableAndFields
2122

@@ -344,22 +345,25 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
344345
return _list
345346

346347

347-
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource) -> str:
348+
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
349+
embedding: bool = True) -> str:
348350
schema_str = ""
349351
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
350352
if len(table_objs) == 0:
351353
return schema_str
352354
db_name = table_objs[0].schema
353355
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
356+
tables = []
354357
for obj in table_objs:
355-
schema_str += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
358+
schema_table = ''
359+
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
356360
table_comment = ''
357361
if obj.table.custom_comment:
358362
table_comment = obj.table.custom_comment.strip()
359363
if table_comment == '':
360-
schema_str += '\n[\n'
364+
schema_table += '\n[\n'
361365
else:
362-
schema_str += f", {table_comment}\n[\n"
366+
schema_table += f", {table_comment}\n[\n"
363367

364368
field_list = []
365369
for field in obj.fields:
@@ -370,7 +374,14 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
370374
field_list.append(f"({field.field_name}:{field.field_type})")
371375
else:
372376
field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
373-
schema_str += ",\n".join(field_list)
374-
schema_str += '\n]\n'
377+
schema_table += ",\n".join(field_list)
378+
schema_table += '\n]\n'
379+
tables.append(schema_table)
380+
# do table embedding
381+
if embedding:
382+
tables = get_table_embedding(session, current_user, tables, question)
383+
375384
# todo 外键
385+
for s in tables:
386+
schema_str += s
376387
return schema_str

backend/apps/datasource/embedding/ds_embedding.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from apps.datasource.crud.datasource import get_table_schema
1010
from apps.datasource.models.datasource import CoreDatasource
1111
from apps.system.crud.assistant import AssistantOutDs
12+
from common.core.config import settings
1213
from common.core.deps import CurrentAssistant
1314
from common.core.deps import SessionDep, CurrentUser
1415
from common.utils.utils import SQLBotLogUtil
@@ -45,7 +46,7 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
4546
for _ds in _ds_list:
4647
if _ds.get('id'):
4748
ds = session.get(CoreDatasource, _ds.get('id'))
48-
table_schema = get_table_schema(session, current_user, ds)
49+
table_schema = get_table_schema(session, current_user, ds, embedding=False)
4950
ds_info = f"{ds.name}, {ds.description}\n"
5051
ds_schema = ds_info + table_schema
5152
_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
@@ -71,3 +72,31 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
7172
return {"id": ds.id, "name": ds.name, "description": ds.description}
7273
except Exception:
7374
traceback.print_exc()
75+
return _list
76+
77+
78+
def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[str], question: str):
    """Keep the table schema snippets that best match the question.

    Every entry of ``tables`` is embedded with the cached embedding model and
    scored against the embedded ``question`` using cosine similarity. The
    entries are then sorted best-first and truncated to
    ``settings.TABLE_EMBEDDING_COUNT`` items.

    Args:
        session: Database session. Not used by this function.
        current_user: Current user. Not used by this function.
        tables: Table schema strings to rank.
        question: Question text the tables are matched against.

    Returns:
        A list of table schema strings. On success it holds the top-scoring
        entries in descending similarity order. If embedding or scoring
        fails, the traceback is printed and all input tables are returned
        unfiltered, in their original order.
    """
    if not tables:
        return []

    try:
        model = EmbeddingModelCache.get_model()
        table_embeddings = model.embed_documents(list(tables))
        q_embedding = model.embed_query(question)

        # Start every table at 0.0 so an entry without an embedding still ranks.
        scored = [{"table_schema": table_schema, "cosine_similarity": 0.0} for table_schema in tables]
        for entry, table_embedding in zip(scored, table_embeddings):
            entry['cosine_similarity'] = cosine_similarity(q_embedding, table_embedding)

        scored.sort(key=lambda x: x['cosine_similarity'], reverse=True)
        scored = scored[:settings.TABLE_EMBEDDING_COUNT]
        SQLBotLogUtil.info(json.dumps(scored))
        return [entry["table_schema"] for entry in scored]
    except Exception:
        # Best-effort: the result must stay a list of strings because callers
        # concatenate the items. Returning the scoring dicts here would turn
        # an embedding failure into a TypeError downstream.
        traceback.print_exc()
        return list(tables)

backend/common/core/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,5 +101,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
101101
PG_POOL_RECYCLE: int = 3600
102102
PG_POOL_PRE_PING: bool = True
103103

104+
TABLE_EMBEDDING_COUNT: int = 30
105+
104106

105107
settings = Settings() # type: ignore

0 commit comments

Comments
 (0)