Skip to content

Commit 5acb04a

Browse files
committed
feat: Vector retrieval matches tables
1 parent 487c8e3 commit 5acb04a

File tree

2 files changed

+32
-29
lines changed

2 files changed

+32
-29
lines changed

backend/apps/datasource/crud/datasource.py

Lines changed: 32 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,26 @@
11
import datetime
22
import json
3+
import traceback
34
from typing import List, Optional
45

56
from fastapi import HTTPException
67
from sqlalchemy import and_, text
78
from sqlmodel import select
89

10+
from apps.ai_model.embedding import EmbeddingModelCache
911
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
12+
from apps.datasource.embedding.ds_embedding import cosine_similarity
1013
from apps.datasource.utils.utils import aes_decrypt
1114
from apps.db.constant import DB
1215
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
1316
from apps.db.engine import get_engine_config, get_engine_conn
17+
from common.core.config import settings
1418
from common.core.deps import SessionDep, CurrentUser, Trans
19+
from common.utils.utils import SQLBotLogUtil
1520
from common.utils.utils import deepcopy_ignore_extra
1621
from .table import get_tables_by_ds_id
1722
from ..crud.field import delete_field_by_ds_id, update_field
1823
from ..crud.table import delete_table_by_ds_id, update_table
19-
from ..embedding.ds_embedding import get_table_embedding
2024
from ..models.datasource import CoreDatasource, CreateDatasource, CoreTable, CoreField, ColumnSchema, TableObj, \
2125
DatasourceConf, TableAndFields
2226

@@ -385,3 +389,30 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
385389
for s in tables:
386390
schema_str += s
387391
return schema_str
392+
393+
394+
def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[str], question: str):
    """Return the table schemas most similar to ``question``, best match first.

    Each schema string in ``tables`` is embedded with the model obtained from
    ``EmbeddingModelCache.get_model()``, scored against the embedded question
    with ``cosine_similarity``, sorted by descending score and truncated to
    ``settings.TABLE_EMBEDDING_COUNT`` entries. The scored entries are logged
    as JSON through ``SQLBotLogUtil.info``.

    Args:
        session: Database session; not referenced by this function.
        current_user: Requesting user; not referenced by this function.
        tables: Table schema strings to rank.
        question: Question text the schemas are compared against.

    Returns:
        A list of table schema strings. If ``tables`` is empty the result is
        an empty list. If anything fails while embedding, scoring or logging,
        the traceback is printed and the original, unfiltered schema strings
        are returned, so the caller always receives strings and never the
        internal score records.
    """
    if not tables:
        return []

    schemas = list(tables)
    try:
        model = EmbeddingModelCache.get_model()
        doc_embeddings = model.embed_documents(schemas)
        q_embedding = model.embed_query(question)

        # Every schema starts at 0.0 so an entry without a matching embedding
        # still takes part in the ranking with the lowest neutral score.
        scored = [{"table_schema": schema, "cosine_similarity": 0.0} for schema in schemas]
        for entry, doc_embedding in zip(scored, doc_embeddings):
            entry['cosine_similarity'] = cosine_similarity(q_embedding, doc_embedding)

        scored.sort(key=lambda entry: entry['cosine_similarity'], reverse=True)
        scored = scored[:settings.TABLE_EMBEDDING_COUNT]
        SQLBotLogUtil.info(json.dumps(scored))
        return [entry["table_schema"] for entry in scored]
    except Exception:
        # Best effort: ranking is an optimisation, so fall back to the full
        # schema list instead of failing the request. Returning strings here
        # (not the score dicts) keeps the return type identical on both paths.
        traceback.print_exc()
        return schemas

backend/apps/datasource/embedding/ds_embedding.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from apps.datasource.crud.datasource import get_table_schema
1010
from apps.datasource.models.datasource import CoreDatasource
1111
from apps.system.crud.assistant import AssistantOutDs
12-
from common.core.config import settings
1312
from common.core.deps import CurrentAssistant
1413
from common.core.deps import SessionDep, CurrentUser
1514
from common.utils.utils import SQLBotLogUtil
@@ -73,30 +72,3 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
7372
except Exception:
7473
traceback.print_exc()
7574
return _list
76-
77-
78-
def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[str], question: str):
    """Return the table schemas most similar to ``question``, best match first.

    Each schema string in ``tables`` is embedded with the model obtained from
    ``EmbeddingModelCache.get_model()``, scored against the embedded question
    with ``cosine_similarity``, sorted by descending score and truncated to
    ``settings.TABLE_EMBEDDING_COUNT`` entries. The scored entries are logged
    as JSON through ``SQLBotLogUtil.info``.

    Args:
        session: Database session; not referenced by this function.
        current_user: Requesting user; not referenced by this function.
        tables: Table schema strings to rank.
        question: Question text the schemas are compared against.

    Returns:
        A list of table schema strings. If ``tables`` is empty the result is
        an empty list. If anything fails while embedding, scoring or logging,
        the traceback is printed and the original, unfiltered schema strings
        are returned, so the caller always receives strings and never the
        internal score records.
    """
    if not tables:
        return []

    schemas = list(tables)
    try:
        model = EmbeddingModelCache.get_model()
        doc_embeddings = model.embed_documents(schemas)
        q_embedding = model.embed_query(question)

        # Every schema starts at 0.0 so an entry without a matching embedding
        # still takes part in the ranking with the lowest neutral score.
        scored = [{"table_schema": schema, "cosine_similarity": 0.0} for schema in schemas]
        for entry, doc_embedding in zip(scored, doc_embeddings):
            entry['cosine_similarity'] = cosine_similarity(q_embedding, doc_embedding)

        scored.sort(key=lambda entry: entry['cosine_similarity'], reverse=True)
        scored = scored[:settings.TABLE_EMBEDDING_COUNT]
        SQLBotLogUtil.info(json.dumps(scored))
        return [entry["table_schema"] for entry in scored]
    except Exception:
        # Best effort: ranking is an optimisation, so fall back to the full
        # schema list instead of failing the request. Returning strings here
        # (not the score dicts) keeps the return type identical on both paths.
        traceback.print_exc()
        return schemas

0 commit comments

Comments
 (0)