|
1 | 1 | import datetime |
2 | 2 | import json |
| 3 | +import traceback |
3 | 4 | from typing import List, Optional |
4 | 5 |
|
5 | 6 | from fastapi import HTTPException |
6 | 7 | from sqlalchemy import and_, text |
7 | 8 | from sqlmodel import select |
8 | 9 |
|
| 10 | +from apps.ai_model.embedding import EmbeddingModelCache |
9 | 11 | from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user |
| 12 | +from apps.datasource.embedding.ds_embedding import cosine_similarity |
10 | 13 | from apps.datasource.utils.utils import aes_decrypt |
11 | 14 | from apps.db.constant import DB |
12 | 15 | from apps.db.db import get_tables, get_fields, exec_sql, check_connection |
13 | 16 | from apps.db.engine import get_engine_config, get_engine_conn |
| 17 | +from common.core.config import settings |
14 | 18 | from common.core.deps import SessionDep, CurrentUser, Trans |
| 19 | +from common.utils.utils import SQLBotLogUtil |
15 | 20 | from common.utils.utils import deepcopy_ignore_extra |
16 | 21 | from .table import get_tables_by_ds_id |
17 | 22 | from ..crud.field import delete_field_by_ds_id, update_field |
18 | 23 | from ..crud.table import delete_table_by_ds_id, update_table |
19 | | -from ..embedding.ds_embedding import get_table_embedding |
20 | 24 | from ..models.datasource import CoreDatasource, CreateDatasource, CoreTable, CoreField, ColumnSchema, TableObj, \ |
21 | 25 | DatasourceConf, TableAndFields |
22 | 26 |
|
@@ -385,3 +389,30 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat |
385 | 389 | for s in tables: |
386 | 390 | schema_str += s |
387 | 391 | return schema_str |
| 392 | + |
| 393 | + |
| 394 | +def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[str], question: str): |
| 395 | + _list = [] |
| 396 | + for table_schema in tables: |
| 397 | + _list.append({"table_schema": table_schema, "cosine_similarity": 0.0}) |
| 398 | + |
| 399 | + if _list: |
| 400 | + try: |
| 401 | + text = [s.get('table_schema') for s in _list] |
| 402 | + |
| 403 | + model = EmbeddingModelCache.get_model() |
| 404 | + results = model.embed_documents(text) |
| 405 | + |
| 406 | + q_embedding = model.embed_query(question) |
| 407 | + for index in range(len(results)): |
| 408 | + item = results[index] |
| 409 | + _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item) |
| 410 | + |
| 411 | + _list.sort(key=lambda x: x['cosine_similarity'], reverse=True) |
| 412 | + _list = _list[:settings.TABLE_EMBEDDING_COUNT] |
| 413 | + # print(len(_list)) |
| 414 | + SQLBotLogUtil.info(json.dumps(_list)) |
| 415 | + return [t.get("table_schema") for t in _list] |
| 416 | + except Exception: |
| 417 | + traceback.print_exc() |
| 418 | + return _list |
0 commit comments