11# Author: Junjun
22# Date: 2025/9/18
33import json
4+ import time
45import traceback
56from typing import Optional
67
78from apps .ai_model .embedding import EmbeddingModelCache
8- from apps .datasource .crud .datasource import get_table_schema
99from apps .datasource .embedding .utils import cosine_similarity
1010from apps .datasource .models .datasource import CoreDatasource
1111from apps .system .crud .assistant import AssistantOutDs
@@ -18,42 +18,71 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
1818 question : str ,
1919 current_assistant : Optional [CurrentAssistant ] = None ):
2020 _list = []
21- if current_assistant and current_assistant .type != 4 :
21+ if current_assistant and current_assistant .type == 1 :
2222 if out_ds .ds_list :
2323 for _ds in out_ds .ds_list :
2424 ds = out_ds .get_ds (_ds .id )
2525 table_schema = out_ds .get_db_schema (_ds .id , question , embedding = False )
2626 ds_info = f"{ ds .name } , { ds .description } \n "
2727 ds_schema = ds_info + table_schema
2828 _list .append ({"id" : ds .id , "ds_schema" : ds_schema , "cosine_similarity" : 0.0 , "ds" : ds })
29+
30+ if _list :
31+ try :
32+ text = [s .get ('ds_schema' ) for s in _list ]
33+
34+ model = EmbeddingModelCache .get_model ()
35+ results = model .embed_documents (text )
36+
37+ q_embedding = model .embed_query (question )
38+ for index in range (len (results )):
39+ item = results [index ]
40+ _list [index ]['cosine_similarity' ] = cosine_similarity (q_embedding , item )
41+
42+ _list .sort (key = lambda x : x ['cosine_similarity' ], reverse = True )
43+ # print(len(_list))
44+ SQLBotLogUtil .info (json .dumps (
45+ [{"id" : ele .get ("id" ), "name" : ele .get ("ds" ).name ,
46+ "cosine_similarity" : ele .get ("cosine_similarity" )}
47+ for ele in _list ]))
48+ ds = _list [0 ].get ('ds' )
49+ return {"id" : ds .id , "name" : ds .name , "description" : ds .description }
50+ except Exception :
51+ traceback .print_exc ()
2952 else :
3053 for _ds in _ds_list :
3154 if _ds .get ('id' ):
3255 ds = session .get (CoreDatasource , _ds .get ('id' ))
33- table_schema = get_table_schema (session , current_user , ds , question , embedding = False )
34- ds_info = f"{ ds .name } , { ds .description } \n "
35- ds_schema = ds_info + table_schema
36- _list .append ({"id" : ds .id , "ds_schema" : ds_schema , "cosine_similarity" : 0.0 , "ds" : ds })
56+ # table_schema = get_table_schema(session, current_user, ds, question, embedding=False)
57+ # ds_info = f"{ds.name}, {ds.description}\n"
58+ # ds_schema = ds_info + table_schema
59+ _list .append ({"id" : ds .id , "cosine_similarity" : 0.0 , "ds" : ds , "embedding" : ds .embedding })
60+
61+ if _list :
62+ try :
63+ # text = [s.get('ds_schema') for s in _list]
64+
65+ model = EmbeddingModelCache .get_model ()
66+ start_time = time .time ()
67+ # results = model.embed_documents(text)
68+ results = [item .get ('embedding' ) for item in _list ]
69+
70+ q_embedding = model .embed_query (question )
71+ for index in range (len (results )):
72+ item = results [index ]
73+ if item :
74+ _list [index ]['cosine_similarity' ] = cosine_similarity (q_embedding , item )
3775
38- if _list :
39- try :
40- text = [s .get ('ds_schema' ) for s in _list ]
41-
42- model = EmbeddingModelCache .get_model ()
43- results = model .embed_documents (text )
44-
45- q_embedding = model .embed_query (question )
46- for index in range (len (results )):
47- item = results [index ]
48- _list [index ]['cosine_similarity' ] = cosine_similarity (q_embedding , item )
49-
50- _list .sort (key = lambda x : x ['cosine_similarity' ], reverse = True )
51- # print(len(_list))
52- SQLBotLogUtil .info (json .dumps (
53- [{"id" : ele .get ("id" ), "name" : ele .get ("ds" ).name , "cosine_similarity" : ele .get ("cosine_similarity" )}
54- for ele in _list ]))
55- ds = _list [0 ].get ('ds' )
56- return {"id" : ds .id , "name" : ds .name , "description" : ds .description }
57- except Exception :
58- traceback .print_exc ()
76+ _list .sort (key = lambda x : x ['cosine_similarity' ], reverse = True )
77+ # print(len(_list))
78+ end_time = time .time ()
79+ SQLBotLogUtil .info (str (end_time - start_time ))
80+ SQLBotLogUtil .info (json .dumps (
81+ [{"id" : ele .get ("id" ), "name" : ele .get ("ds" ).name ,
82+ "cosine_similarity" : ele .get ("cosine_similarity" )}
83+ for ele in _list ]))
84+ ds = _list [0 ].get ('ds' )
85+ return {"id" : ds .id , "name" : ds .name , "description" : ds .description }
86+ except Exception :
87+ traceback .print_exc ()
5988 return _list
0 commit comments