Skip to content

Commit 5a2c049

Browse files
Merge branch 'main' into pr@main@feat_embedding
2 parents 476743b + 5acb04a commit 5a2c049

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

backend/apps/datasource/crud/datasource.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
import datetime
22
import json
3+
import traceback
34
from typing import List, Optional
45

56
from fastapi import HTTPException
67
from sqlalchemy import and_, text
78
from sqlmodel import select
89

10+
from apps.ai_model.embedding import EmbeddingModelCache
911
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
1012
from apps.datasource.embedding.table_embedding import get_table_embedding
1113
from apps.datasource.utils.utils import aes_decrypt
1214
from apps.db.constant import DB
1315
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
1416
from apps.db.engine import get_engine_config, get_engine_conn
17+
from common.core.config import settings
1518
from common.core.deps import SessionDep, CurrentUser, Trans
19+
from common.utils.utils import SQLBotLogUtil
1620
from common.utils.utils import deepcopy_ignore_extra
1721
from .table import get_tables_by_ds_id
1822
from ..crud.field import delete_field_by_ds_id, update_field
@@ -385,3 +389,30 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
385389
for s in tables:
386390
schema_str += s
387391
return schema_str
392+
393+
394+
def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[str], question: str):
395+
_list = []
396+
for table_schema in tables:
397+
_list.append({"table_schema": table_schema, "cosine_similarity": 0.0})
398+
399+
if _list:
400+
try:
401+
text = [s.get('table_schema') for s in _list]
402+
403+
model = EmbeddingModelCache.get_model()
404+
results = model.embed_documents(text)
405+
406+
q_embedding = model.embed_query(question)
407+
for index in range(len(results)):
408+
item = results[index]
409+
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
410+
411+
_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
412+
_list = _list[:settings.TABLE_EMBEDDING_COUNT]
413+
# print(len(_list))
414+
SQLBotLogUtil.info(json.dumps(_list))
415+
return [t.get("table_schema") for t in _list]
416+
except Exception:
417+
traceback.print_exc()
418+
return _list

0 commit comments

Comments
 (0)