Skip to content

Commit b1dda8b

Browse files
committed
feat: Vector retrieval matches tables
1 parent c745eaa commit b1dda8b

File tree

8 files changed

+151
-43
lines changed

8 files changed

+151
-43
lines changed

backend/alembic/versions/047_table_embedding.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
def upgrade():
2121
# ### commands auto generated by Alembic - please adjust! ###
2222
op.add_column('core_table', sa.Column('embedding', sa.Text(), nullable=True))
23+
op.add_column('core_datasource', sa.Column('embedding', sa.Text(), nullable=True))
2324
# ### end Alembic commands ###
2425

2526

2627
def downgrade():
2728
# ### commands auto generated by Alembic - please adjust! ###
2829
op.drop_column('core_table', 'embedding')
30+
op.drop_column('core_datasource', 'embedding')
2931
# ### end Alembic commands ###

backend/apps/datasource/crud/datasource.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from apps.db.engine import get_engine_config, get_engine_conn
1616
from common.core.config import settings
1717
from common.core.deps import SessionDep, CurrentUser, Trans
18-
from common.utils.embedding_threads import run_save_table_embeddings
18+
from common.utils.embedding_threads import run_save_table_embeddings, run_save_ds_embeddings
1919
from common.utils.utils import deepcopy_ignore_extra
2020
from .table import get_tables_by_ds_id
2121
from ..crud.field import delete_field_by_ds_id, update_field
@@ -105,6 +105,8 @@ def update_ds(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreData
105105
setattr(record, field, value)
106106
session.add(record)
107107
session.commit()
108+
109+
run_save_ds_embeddings([ds.id])
108110
return ds
109111

110112

@@ -197,6 +199,7 @@ def sync_table(session: SessionDep, ds: CoreDatasource, tables: List[CoreTable])
197199

198200
# do table embedding
199201
run_save_table_embeddings(id_list)
202+
run_save_ds_embeddings([ds.id])
200203

201204

202205
def sync_fields(session: SessionDep, ds: CoreDatasource, table: CoreTable, fields: List[ColumnSchema]):
@@ -238,20 +241,23 @@ def update_table_and_fields(session: SessionDep, data: TableObj):
238241

239242
# do table embedding
240243
run_save_table_embeddings([data.table.id])
244+
run_save_ds_embeddings([data.table.ds_id])
241245

242246

243247
def updateTable(session: SessionDep, table: CoreTable):
244248
update_table(session, table)
245249

246250
# do table embedding
247251
run_save_table_embeddings([table.id])
252+
run_save_ds_embeddings([table.ds_id])
248253

249254

250255
def updateField(session: SessionDep, field: CoreField):
251256
update_field(session, field)
252257

253258
# do table embedding
254259
run_save_table_embeddings([field.table_id])
260+
run_save_ds_embeddings([field.ds_id])
255261

256262

257263
def preview(session: SessionDep, current_user: CurrentUser, id: int, data: TableObj):

backend/apps/datasource/crud/table.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from common.core.config import settings
1010
from common.core.deps import SessionDep
1111
from common.utils.utils import SQLBotLogUtil
12-
from ..models.datasource import CoreTable, CoreField
12+
from ..models.datasource import CoreTable, CoreField, CoreDatasource
1313

1414

1515
def delete_table_by_ds_id(session: SessionDep, id: int):
@@ -30,18 +30,24 @@ def update_table(session: SessionDep, item: CoreTable):
3030
session.commit()
3131

3232

33-
def run_fill_empty_table_embedding(session_maker):
33+
def run_fill_empty_table_and_ds_embedding(session_maker):
3434
try:
3535
if not settings.TABLE_EMBEDDING_ENABLED:
3636
return
3737

38-
SQLBotLogUtil.info('get tables')
3938
session = session_maker()
39+
40+
SQLBotLogUtil.info('get tables')
4041
stmt = select(CoreTable.id).where(and_(CoreTable.embedding.is_(None)))
4142
results = session.execute(stmt).scalars().all()
42-
SQLBotLogUtil.info('result: ' + str(len(results)))
43-
44-
save_table_embedding(session, results)
43+
SQLBotLogUtil.info('table result: ' + str(len(results)))
44+
save_table_embedding(session_maker, results)
45+
46+
SQLBotLogUtil.info('get datasource')
47+
ds_stmt = select(CoreDatasource.id).where(and_(CoreDatasource.embedding.is_(None)))
48+
ds_results = session.execute(ds_stmt).scalars().all()
49+
SQLBotLogUtil.info('datasource result: ' + str(len(ds_results)))
50+
save_ds_embedding(session_maker, ds_results)
4551
except Exception:
4652
traceback.print_exc()
4753
finally:
@@ -98,3 +104,58 @@ def save_table_embedding(session_maker, ids: List[int]):
98104
traceback.print_exc()
99105
finally:
100106
session_maker.remove()
107+
108+
109+
def save_ds_embedding(session_maker, ids: List[int]):
110+
if not settings.TABLE_EMBEDDING_ENABLED:
111+
return
112+
113+
if not ids or len(ids) == 0:
114+
return
115+
try:
116+
SQLBotLogUtil.info('start datasource embedding')
117+
start_time = time.time()
118+
model = EmbeddingModelCache.get_model()
119+
session = session_maker()
120+
for _id in ids:
121+
schema_table = ''
122+
ds = session.query(CoreDatasource).filter(CoreDatasource.id == _id).first()
123+
schema_table += f"{ds.name}, {ds.description}\n"
124+
tables = session.query(CoreTable).filter(CoreTable.ds_id == ds.id).all()
125+
for table in tables:
126+
fields = session.query(CoreField).filter(CoreField.table_id == table.id).all()
127+
128+
schema_table += f"# Table: {table.table_name}"
129+
table_comment = ''
130+
if table.custom_comment:
131+
table_comment = table.custom_comment.strip()
132+
if table_comment == '':
133+
schema_table += '\n[\n'
134+
else:
135+
schema_table += f", {table_comment}\n[\n"
136+
137+
if fields:
138+
field_list = []
139+
for field in fields:
140+
field_comment = ''
141+
if field.custom_comment:
142+
field_comment = field.custom_comment.strip()
143+
if field_comment == '':
144+
field_list.append(f"({field.field_name}:{field.field_type})")
145+
else:
146+
field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
147+
schema_table += ",\n".join(field_list)
148+
schema_table += '\n]\n'
149+
# table_schema.append(schema_table)
150+
emb = json.dumps(model.embed_query(schema_table))
151+
152+
stmt = update(CoreDatasource).where(and_(CoreDatasource.id == _id)).values(embedding=emb)
153+
session.execute(stmt)
154+
session.commit()
155+
156+
end_time = time.time()
157+
SQLBotLogUtil.info('datasource embedding finished in: ' + str(end_time - start_time) + ' seconds')
158+
except Exception:
159+
traceback.print_exc()
160+
finally:
161+
session_maker.remove()
Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Author: Junjun
22
# Date: 2025/9/18
33
import json
4+
import time
45
import traceback
56
from typing import Optional
67

78
from apps.ai_model.embedding import EmbeddingModelCache
8-
from apps.datasource.crud.datasource import get_table_schema
99
from apps.datasource.embedding.utils import cosine_similarity
1010
from apps.datasource.models.datasource import CoreDatasource
1111
from apps.system.crud.assistant import AssistantOutDs
@@ -18,42 +18,71 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
1818
question: str,
1919
current_assistant: Optional[CurrentAssistant] = None):
2020
_list = []
21-
if current_assistant and current_assistant.type != 4:
21+
if current_assistant and current_assistant.type == 1:
2222
if out_ds.ds_list:
2323
for _ds in out_ds.ds_list:
2424
ds = out_ds.get_ds(_ds.id)
2525
table_schema = out_ds.get_db_schema(_ds.id, question, embedding=False)
2626
ds_info = f"{ds.name}, {ds.description}\n"
2727
ds_schema = ds_info + table_schema
2828
_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
29+
30+
if _list:
31+
try:
32+
text = [s.get('ds_schema') for s in _list]
33+
34+
model = EmbeddingModelCache.get_model()
35+
results = model.embed_documents(text)
36+
37+
q_embedding = model.embed_query(question)
38+
for index in range(len(results)):
39+
item = results[index]
40+
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
41+
42+
_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
43+
# print(len(_list))
44+
SQLBotLogUtil.info(json.dumps(
45+
[{"id": ele.get("id"), "name": ele.get("ds").name,
46+
"cosine_similarity": ele.get("cosine_similarity")}
47+
for ele in _list]))
48+
ds = _list[0].get('ds')
49+
return {"id": ds.id, "name": ds.name, "description": ds.description}
50+
except Exception:
51+
traceback.print_exc()
2952
else:
3053
for _ds in _ds_list:
3154
if _ds.get('id'):
3255
ds = session.get(CoreDatasource, _ds.get('id'))
33-
table_schema = get_table_schema(session, current_user, ds, question, embedding=False)
34-
ds_info = f"{ds.name}, {ds.description}\n"
35-
ds_schema = ds_info + table_schema
36-
_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
56+
# table_schema = get_table_schema(session, current_user, ds, question, embedding=False)
57+
# ds_info = f"{ds.name}, {ds.description}\n"
58+
# ds_schema = ds_info + table_schema
59+
_list.append({"id": ds.id, "cosine_similarity": 0.0, "ds": ds, "embedding": ds.embedding})
60+
61+
if _list:
62+
try:
63+
# text = [s.get('ds_schema') for s in _list]
64+
65+
model = EmbeddingModelCache.get_model()
66+
start_time = time.time()
67+
# results = model.embed_documents(text)
68+
results = [item.get('embedding') for item in _list]
69+
70+
q_embedding = model.embed_query(question)
71+
for index in range(len(results)):
72+
item = results[index]
73+
if item:
74+
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
3775

38-
if _list:
39-
try:
40-
text = [s.get('ds_schema') for s in _list]
41-
42-
model = EmbeddingModelCache.get_model()
43-
results = model.embed_documents(text)
44-
45-
q_embedding = model.embed_query(question)
46-
for index in range(len(results)):
47-
item = results[index]
48-
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
49-
50-
_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
51-
# print(len(_list))
52-
SQLBotLogUtil.info(json.dumps(
53-
[{"id": ele.get("id"), "name": ele.get("ds").name, "cosine_similarity": ele.get("cosine_similarity")}
54-
for ele in _list]))
55-
ds = _list[0].get('ds')
56-
return {"id": ds.id, "name": ds.name, "description": ds.description}
57-
except Exception:
58-
traceback.print_exc()
76+
_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
77+
# print(len(_list))
78+
end_time = time.time()
79+
SQLBotLogUtil.info(str(end_time - start_time))
80+
SQLBotLogUtil.info(json.dumps(
81+
[{"id": ele.get("id"), "name": ele.get("ds").name,
82+
"cosine_similarity": ele.get("cosine_similarity")}
83+
for ele in _list]))
84+
ds = _list[0].get('ds')
85+
return {"id": ds.id, "name": ds.name, "description": ds.description}
86+
except Exception:
87+
traceback.print_exc()
5988
return _list

backend/apps/datasource/embedding/table_embedding.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def calc_table_embedding(tables: list[dict], question: str):
5252
# text = [s.get('schema_table') for s in _list]
5353
#
5454
model = EmbeddingModelCache.get_model()
55-
# start_time = time.time()
55+
start_time = time.time()
5656
# results = model.embed_documents(text)
5757
# end_time = time.time()
5858
# SQLBotLogUtil.info(str(end_time - start_time))
@@ -67,7 +67,11 @@ def calc_table_embedding(tables: list[dict], question: str):
6767
_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
6868
_list = _list[:settings.TABLE_EMBEDDING_COUNT]
6969
# print(len(_list))
70-
SQLBotLogUtil.info(json.dumps(_list))
70+
end_time = time.time()
71+
SQLBotLogUtil.info(str(end_time - start_time))
72+
SQLBotLogUtil.info(json.dumps([{"id": ele.get('id'), "schema_table": ele.get('schema_table'),
73+
"cosine_similarity": ele.get('cosine_similarity')}
74+
for ele in _list]))
7175
return _list
7276
except Exception:
7377
traceback.print_exc()

backend/apps/datasource/models/datasource.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class CoreDatasource(SQLModel, table=True):
2121
num: str = Field(max_length=256, nullable=True)
2222
oid: int = Field(sa_column=Column(BigInteger()))
2323
table_relation: List = Field(sa_column=Column(JSONB, nullable=True))
24+
embedding: str = Field(sa_column=Column(Text, nullable=True))
2425

2526

2627
class CoreTable(SQLModel, table=True):

backend/common/utils/embedding_threads.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def run_save_table_embeddings(ids: List[int]):
3838
executor.submit(save_table_embedding, session_maker, ids)
3939

4040

41-
def fill_empty_table_embeddings():
42-
from apps.datasource.crud.table import run_fill_empty_table_embedding
43-
executor.submit(run_fill_empty_table_embedding, session_maker)
41+
def run_save_ds_embeddings(ids: List[int]):
42+
from apps.datasource.crud.table import save_ds_embedding
43+
executor.submit(save_ds_embedding, session_maker, ids)
44+
45+
46+
def fill_empty_table_and_ds_embeddings():
47+
from apps.datasource.crud.table import run_fill_empty_table_and_ds_embedding
48+
executor.submit(run_fill_empty_table_and_ds_embedding, session_maker)

backend/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from alembic import command
1414
from apps.api import api_router
15-
from common.utils.embedding_threads import fill_empty_table_embeddings
15+
from common.utils.embedding_threads import fill_empty_table_and_ds_embeddings
1616
from apps.system.crud.aimodel_manage import async_model_info
1717
from apps.system.crud.assistant import init_dynamic_cors
1818
from apps.system.middleware.auth import TokenMiddleware
@@ -36,8 +36,8 @@ def init_data_training_embedding_data():
3636
fill_empty_data_training_embeddings()
3737

3838

39-
def init_table_embedding():
40-
fill_empty_table_embeddings()
39+
def init_table_and_ds_embedding():
40+
fill_empty_table_and_ds_embeddings()
4141

4242

4343
@asynccontextmanager
@@ -47,7 +47,7 @@ async def lifespan(app: FastAPI):
4747
init_dynamic_cors(app)
4848
init_terminology_embedding_data()
4949
init_data_training_embedding_data()
50-
init_table_embedding()
50+
init_table_and_ds_embedding()
5151
SQLBotLogUtil.info("✅ SQLBot 初始化完成")
5252
await sqlbot_xpack.core.clean_xpack_cache()
5353
await async_model_info() # 异步加密已有模型的密钥和地址

0 commit comments

Comments
 (0)