Skip to content

Commit 24b9c9a

Browse files
committed
feat: Vector retrieval matches tables
1 parent 41b0938 commit 24b9c9a

File tree

6 files changed

+138
-6
lines changed

6 files changed

+138
-6
lines changed

backend/apps/datasource/crud/datasource.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
from sqlmodel import select
99

1010
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
11-
from apps.datasource.embedding.table_embedding import get_table_embedding
11+
from apps.datasource.embedding.table_embedding import get_table_embedding, calc_table_embedding
1212
from apps.datasource.utils.utils import aes_decrypt
1313
from apps.db.constant import DB
1414
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
1515
from apps.db.engine import get_engine_config, get_engine_conn
1616
from common.core.config import settings
1717
from common.core.deps import SessionDep, CurrentUser, Trans
18+
from common.utils.embedding_threads import run_save_table_embeddings
1819
from common.utils.utils import deepcopy_ignore_extra
1920
from .table import get_tables_by_ds_id
2021
from ..crud.field import delete_field_by_ds_id, update_field
@@ -194,6 +195,9 @@ def sync_table(session: SessionDep, ds: CoreDatasource, tables: List[CoreTable])
194195
session.query(CoreField).filter(CoreField.ds_id == ds.id).delete(synchronize_session=False)
195196
session.commit()
196197

198+
# do table embedding
199+
run_save_table_embeddings(id_list)
200+
197201

198202
def sync_fields(session: SessionDep, ds: CoreDatasource, table: CoreTable, fields: List[ColumnSchema]):
199203
id_list = []
@@ -232,14 +236,23 @@ def update_table_and_fields(session: SessionDep, data: TableObj):
232236
for field in data.fields:
233237
update_field(session, field)
234238

239+
# do table embedding
240+
run_save_table_embeddings([data.table.id])
241+
235242

236243
def updateTable(session: SessionDep, table: CoreTable):
    """Save a table edit and queue a refresh of that table's embedding.

    Args:
        session: Database session passed through to ``update_table``.
        table: The table record carrying the edited values; its ``id`` is
            used to select which embedding to recompute.
    """
    update_table(session, table)

    # do table embedding
    table_ids = [table.id]
    run_save_table_embeddings(table_ids)
248+
239249

240250
def updateField(session: SessionDep, field: CoreField):
    """Save a field edit and queue a refresh of the owning table's embedding.

    Args:
        session: Database session passed through to ``update_field``.
        field: The field record carrying the edited values; its ``table_id``
            selects which table embedding to recompute.
    """
    update_field(session, field)

    # do table embedding
    owning_table_ids = [field.table_id]
    run_save_table_embeddings(owning_table_ids)
255+
243256

244257
def preview(session: SessionDep, current_user: CurrentUser, id: int, data: TableObj):
245258
ds = session.query(CoreDatasource).filter(CoreDatasource.id == id).first()
@@ -398,13 +411,13 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
398411
schema_table += ",\n".join(field_list)
399412
schema_table += '\n]\n'
400413

401-
t_obj = {"id": obj.table.id, "schema_table": schema_table}
414+
t_obj = {"id": obj.table.id, "schema_table": schema_table, "embedding": obj.table.embedding}
402415
tables.append(t_obj)
403416
all_tables.append(t_obj)
404417

405418
# do table embedding
406419
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
407-
tables = get_table_embedding(tables, question)
420+
tables = calc_table_embedding(tables, question)
408421
# splice schema
409422
if tables:
410423
for s in tables:

backend/apps/datasource/crud/table.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
import traceback
2+
from typing import List
3+
4+
from sqlalchemy import and_, select, update
5+
6+
from apps.ai_model.embedding import EmbeddingModelCache
7+
from common.core.config import settings
18
from common.core.deps import SessionDep
2-
from ..models.datasource import CoreDatasource, CreateDatasource, CoreTable, CoreField, ColumnSchema
3-
from sqlalchemy import and_
9+
from ..models.datasource import CoreTable, CoreField
410

511

612
def delete_table_by_ds_id(session: SessionDep, id: int):
@@ -19,3 +25,65 @@ def update_table(session: SessionDep, item: CoreTable):
1925
record.custom_comment = item.custom_comment
2026
session.add(record)
2127
session.commit()
28+
29+
30+
def run_fill_empty_table_embedding(session: SessionDep):
    """Compute embeddings for every table row that does not have one yet.

    Does nothing when ``settings.EMBEDDING_ENABLED`` is falsy. Otherwise the
    ids of all ``CoreTable`` rows whose ``embedding`` column is NULL are
    collected and handed to ``save_table_embedding``.

    Args:
        session: Database session used for the lookup and the later save.
    """
    if settings.EMBEDDING_ENABLED:
        missing_stmt = select(CoreTable.id).where(and_(CoreTable.embedding.is_(None)))
        missing_ids = session.execute(missing_stmt).scalars().all()

        save_table_embedding(session, missing_ids)
38+
39+
40+
def _build_table_schema_text(table: CoreTable, fields: List[CoreField]) -> str:
    """Render one table and its fields as the text that gets embedded.

    The layout is ``# Table: <name>[, <comment>]`` followed by a bracketed,
    comma-separated list of ``(<field>:<type>[, <comment>])`` entries. Only
    the user-supplied ``custom_comment`` values are used, stripped of
    surrounding whitespace; blank comments are omitted.
    """
    table_comment = table.custom_comment.strip() if table.custom_comment else ''
    header = f"# Table: {table.table_name}"
    if table_comment:
        header += f", {table_comment}"

    field_list = []
    for field in fields or []:
        field_comment = field.custom_comment.strip() if field.custom_comment else ''
        if field_comment:
            field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
        else:
            field_list.append(f"({field.field_name}:{field.field_type})")

    return header + '\n[\n' + ",\n".join(field_list) + '\n]\n'


def save_table_embedding(session: SessionDep, ids: List[int]):
    """Compute and store embeddings for the given table ids.

    For each ``CoreTable`` row whose id is in ``ids`` a schema description is
    built from the table and its ``CoreField`` rows, embedded with the model
    returned by ``EmbeddingModelCache.get_model()``, and written to the row's
    ``embedding`` column.

    The function is best-effort: any exception is printed and swallowed so a
    failure never propagates to the caller. On failure the session is rolled
    back so it stays usable for later work.

    Args:
        session: Database session used for reads and the update.
        ids: Ids of the tables to (re-)embed. ``None`` or empty is a no-op.
    """
    if not settings.EMBEDDING_ENABLED:
        return

    if not ids:
        return

    try:
        tables = session.query(CoreTable).filter(and_(CoreTable.id.in_(ids))).all()
        if not tables:
            # None of the requested ids exist; avoid calling the model with
            # an empty batch.
            return

        table_schema = [
            _build_table_schema_text(
                table,
                session.query(CoreField).filter(CoreField.table_id == table.id).all(),
            )
            for table in tables
        ]

        model = EmbeddingModelCache.get_model()
        embeddings = model.embed_documents(table_schema)

        # embed_documents is expected to return one vector per input text, in
        # input order, so tables and vectors are paired positionally.
        for table, embedding in zip(tables, embeddings):
            stmt = update(CoreTable).where(and_(CoreTable.id == table.id)).values(embedding=embedding)
            session.execute(stmt)
        # Single commit so the batch is stored atomically.
        session.commit()

    except Exception:
        traceback.print_exc()
        # A failed statement/commit leaves the session's transaction in an
        # invalid state until it is rolled back.
        session.rollback()

backend/apps/datasource/embedding/table_embedding.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,36 @@ def get_table_embedding(tables: list[dict], question: str):
3838
except Exception:
3939
traceback.print_exc()
4040
return _list
41+
42+
43+
def calc_table_embedding(tables: list[dict], question: str):
    """Rank tables by similarity between the question and stored embeddings.

    Unlike computing document embeddings on every call, this uses the
    pre-computed vector carried in each table dict under ``"embedding"`` and
    only embeds the question itself.

    Args:
        tables: Dicts with the keys ``id``, ``schema_table`` and
            ``embedding`` (missing keys are read as ``None``).
        question: The user question to embed and compare against.

    Returns:
        A new list of dicts with the keys ``id``, ``schema_table``,
        ``embedding`` and ``cosine_similarity``. On success it is sorted by
        descending similarity and truncated to
        ``settings.TABLE_EMBEDDING_COUNT`` entries. If anything raises while
        ranking, the traceback is printed and the list is returned in
        whatever state it had reached (best-effort fallback).
    """
    _list = [
        {
            "id": table.get('id'),
            "schema_table": table.get('schema_table'),
            "embedding": table.get('embedding'),
            "cosine_similarity": 0.0,
        }
        for table in tables
    ]

    if not _list:
        return _list

    try:
        model = EmbeddingModelCache.get_model()
        q_embedding = model.embed_query(question)

        # NOTE(review): a table whose embedding has not been stored yet
        # carries None here; that presumably makes cosine_similarity raise
        # and triggers the fallback below - confirm that is intended.
        for item in _list:
            item['cosine_similarity'] = cosine_similarity(q_embedding, item['embedding'])

        _list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
        _list = _list[:settings.TABLE_EMBEDDING_COUNT]

        # Log only the ranking. The stored vectors are large and, when
        # loaded from a vector column, are typically not JSON-serialisable,
        # which would make json.dumps raise on every call.
        ranking = [{"id": item['id'], "cosine_similarity": item['cosine_similarity']} for item in _list]
        SQLBotLogUtil.info(json.dumps(ranking))
    except Exception:
        traceback.print_exc()
    return _list

backend/apps/datasource/models/datasource.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime
22
from typing import List, Optional
33

4+
from pgvector.sqlalchemy import VECTOR
45
from pydantic import BaseModel
56
from sqlalchemy import Column, Text, BigInteger, DateTime, Identity
67
from sqlalchemy.dialects.postgresql import JSONB
@@ -31,6 +32,7 @@ class CoreTable(SQLModel, table=True):
3132
table_name: str = Field(sa_column=Column(Text))
3233
table_comment: str = Field(sa_column=Column(Text))
3334
custom_comment: str = Field(sa_column=Column(Text))
35+
embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True))
3436

3537

3638
class CoreField(SQLModel, table=True):

backend/common/utils/embedding_threads.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,13 @@ def run_save_data_training_embeddings(ids: List[int]):
3131
def fill_empty_data_training_embeddings():
    """Schedule ``run_fill_empty_embeddings(session)`` via ``executor.submit``.

    The task function is imported inside the function body rather than at
    module level (presumably to avoid an import cycle - not verifiable from
    this file alone).
    """
    from apps.data_training.curd.data_training import run_fill_empty_embeddings as task

    executor.submit(task, session)
34+
35+
36+
def run_save_table_embeddings(ids: List[int]):
    """Schedule ``save_table_embedding(session, ids)`` via ``executor.submit``.

    The call returns as soon as the task is submitted; the future returned by
    ``submit`` is not kept.

    Args:
        ids: Ids of the tables whose embeddings should be recomputed.
    """
    from apps.datasource.crud.table import save_table_embedding as task

    executor.submit(task, session, ids)
39+
40+
41+
def fill_empty_table_embeddings():
    """Schedule ``run_fill_empty_table_embedding(session)`` via ``executor.submit``.

    The call returns as soon as the task is submitted; the future returned by
    ``submit`` is not kept.
    """
    from apps.datasource.crud.table import run_fill_empty_table_embedding as task

    executor.submit(task, session)

backend/main.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from common.core.config import settings
1919
from common.core.response_middleware import ResponseMiddleware, exception_handler
2020
from common.core.sqlbot_cache import init_sqlbot_cache
21-
from common.utils.embedding_threads import fill_empty_terminology_embeddings, fill_empty_data_training_embeddings
21+
from common.utils.embedding_threads import fill_empty_terminology_embeddings, fill_empty_data_training_embeddings, \
22+
fill_empty_table_embeddings
2223
from common.utils.utils import SQLBotLogUtil
2324

2425

@@ -35,13 +36,18 @@ def init_data_training_embedding_data():
3536
fill_empty_data_training_embeddings()
3637

3738

39+
def init_table_embedding():
    """Startup hook: trigger ``fill_empty_table_embeddings()``."""
    fill_empty_table_embeddings()
41+
42+
3843
@asynccontextmanager
3944
async def lifespan(app: FastAPI):
4045
run_migrations()
4146
init_sqlbot_cache()
4247
init_dynamic_cors(app)
4348
init_terminology_embedding_data()
4449
init_data_training_embedding_data()
50+
init_table_embedding()
4551
SQLBotLogUtil.info("✅ SQLBot 初始化完成")
4652
await sqlbot_xpack.core.clean_xpack_cache()
4753
await async_model_info() # 异步加密已有模型的密钥和地址

0 commit comments

Comments
 (0)