Skip to content

Commit eb82d4f

Browse files
committed
feat: Vector retrieval matches tables
1 parent 5acb04a commit eb82d4f

File tree

4 files changed

+58
-49
lines changed

4 files changed

+58
-49
lines changed

backend/apps/datasource/crud/datasource.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,18 @@
11
import datetime
22
import json
3-
import traceback
43
from typing import List, Optional
54

65
from fastapi import HTTPException
76
from sqlalchemy import and_, text
87
from sqlmodel import select
98

10-
from apps.ai_model.embedding import EmbeddingModelCache
119
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
12-
from apps.datasource.embedding.ds_embedding import cosine_similarity
10+
from apps.datasource.embedding.table_embedding import get_table_embedding
1311
from apps.datasource.utils.utils import aes_decrypt
1412
from apps.db.constant import DB
1513
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
1614
from apps.db.engine import get_engine_config, get_engine_conn
17-
from common.core.config import settings
1815
from common.core.deps import SessionDep, CurrentUser, Trans
19-
from common.utils.utils import SQLBotLogUtil
2016
from common.utils.utils import deepcopy_ignore_extra
2117
from .table import get_tables_by_ds_id
2218
from ..crud.field import delete_field_by_ds_id, update_field
@@ -389,30 +385,3 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
389385
for s in tables:
390386
schema_str += s
391387
return schema_str
392-
393-
394-
def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[str], question: str):
395-
_list = []
396-
for table_schema in tables:
397-
_list.append({"table_schema": table_schema, "cosine_similarity": 0.0})
398-
399-
if _list:
400-
try:
401-
text = [s.get('table_schema') for s in _list]
402-
403-
model = EmbeddingModelCache.get_model()
404-
results = model.embed_documents(text)
405-
406-
q_embedding = model.embed_query(question)
407-
for index in range(len(results)):
408-
item = results[index]
409-
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
410-
411-
_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
412-
_list = _list[:settings.TABLE_EMBEDDING_COUNT]
413-
# print(len(_list))
414-
SQLBotLogUtil.info(json.dumps(_list))
415-
return [t.get("table_schema") for t in _list]
416-
except Exception:
417-
traceback.print_exc()
418-
return _list

backend/apps/datasource/embedding/ds_embedding.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,19 @@
11
# Author: Junjun
22
# Date: 2025/9/18
33
import json
4-
import math
54
import traceback
65
from typing import Optional
76

87
from apps.ai_model.embedding import EmbeddingModelCache
98
from apps.datasource.crud.datasource import get_table_schema
9+
from apps.datasource.embedding.utils import cosine_similarity
1010
from apps.datasource.models.datasource import CoreDatasource
1111
from apps.system.crud.assistant import AssistantOutDs
1212
from common.core.deps import CurrentAssistant
1313
from common.core.deps import SessionDep, CurrentUser
1414
from common.utils.utils import SQLBotLogUtil
1515

1616

17-
def cosine_similarity(vec_a, vec_b):
18-
if len(vec_a) != len(vec_b):
19-
raise ValueError("The vector dimension must be the same")
20-
21-
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
22-
23-
norm_a = math.sqrt(sum(a * a for a in vec_a))
24-
norm_b = math.sqrt(sum(b * b for b in vec_b))
25-
26-
if norm_a == 0 or norm_b == 0:
27-
return 0.0
28-
29-
return dot_product / (norm_a * norm_b)
30-
31-
3217
def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, out_ds: AssistantOutDs,
3318
question: str,
3419
current_assistant: Optional[CurrentAssistant] = None):
@@ -45,7 +30,7 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
4530
for _ds in _ds_list:
4631
if _ds.get('id'):
4732
ds = session.get(CoreDatasource, _ds.get('id'))
48-
table_schema = get_table_schema(session, current_user, ds, embedding=False)
33+
table_schema = get_table_schema(session, current_user, ds, question, embedding=False)
4934
ds_info = f"{ds.name}, {ds.description}\n"
5035
ds_schema = ds_info + table_schema
5136
_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Author: Junjun
2+
# Date: 2025/9/23
3+
import json
4+
import traceback
5+
6+
from apps.ai_model.embedding import EmbeddingModelCache
7+
from apps.datasource.embedding.utils import cosine_similarity
8+
from common.core.config import settings
9+
from common.core.deps import SessionDep, CurrentUser
10+
from common.utils.utils import SQLBotLogUtil
11+
12+
13+
def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[str], question: str):
14+
_list = []
15+
for table_schema in tables:
16+
_list.append({"table_schema": table_schema, "cosine_similarity": 0.0})
17+
18+
if _list:
19+
try:
20+
text = [s.get('table_schema') for s in _list]
21+
22+
model = EmbeddingModelCache.get_model()
23+
results = model.embed_documents(text)
24+
25+
q_embedding = model.embed_query(question)
26+
for index in range(len(results)):
27+
item = results[index]
28+
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
29+
30+
_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
31+
_list = _list[:settings.TABLE_EMBEDDING_COUNT]
32+
# print(len(_list))
33+
SQLBotLogUtil.info(json.dumps(_list))
34+
return [t.get("table_schema") for t in _list]
35+
except Exception:
36+
traceback.print_exc()
37+
return _list
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Author: Junjun
2+
# Date: 2025/9/23
3+
import math
4+
5+
6+
def cosine_similarity(vec_a, vec_b):
7+
if len(vec_a) != len(vec_b):
8+
raise ValueError("The vector dimension must be the same")
9+
10+
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
11+
12+
norm_a = math.sqrt(sum(a * a for a in vec_a))
13+
norm_b = math.sqrt(sum(b * b for b in vec_b))
14+
15+
if norm_a == 0 or norm_b == 0:
16+
return 0.0
17+
18+
return dot_product / (norm_a * norm_b)

0 commit comments

Comments
 (0)