@@ -48,18 +48,18 @@ def save_table_embedding(session: SessionDep, ids: List[int]):
4848 if not ids or len (ids ) == 0 :
4949 return
5050 try :
51-
52- _list = session . query ( CoreTable ). filter ( and_ ( CoreTable . id . in_ ( ids ))). all ()
53-
54- table_schema = []
55- for item in _list :
56- fields = session .query (CoreField ).filter (CoreField .table_id == item .id ).all ()
51+ SQLBotLogUtil . info ( 'start table embedding' )
52+ start_time = time . time ()
53+ model = EmbeddingModelCache . get_model ()
54+ for id in ids :
55+ table = session . query ( CoreTable ). filter ( CoreTable . id == id ). first ()
56+ fields = session .query (CoreField ).filter (CoreField .table_id == table .id ).all ()
5757
5858 schema_table = ''
59- schema_table += f"# Table: { item .table_name } "
59+ schema_table += f"# Table: { table .table_name } "
6060 table_comment = ''
61- if item .custom_comment :
62- table_comment = item .custom_comment .strip ()
61+ if table .custom_comment :
62+ table_comment = table .custom_comment .strip ()
6363 if table_comment == '' :
6464 schema_table += '\n [\n '
6565 else :
@@ -77,22 +77,14 @@ def save_table_embedding(session: SessionDep, ids: List[int]):
7777 field_list .append (f"({ field .field_name } :{ field .field_type } , { field_comment } )" )
7878 schema_table += ",\n " .join (field_list )
7979 schema_table += '\n ]\n '
80- table_schema .append (schema_table )
81-
82- model = EmbeddingModelCache .get_model ()
80+ # table_schema.append(schema_table)
81+ emb = model .embed_query (schema_table )
8382
84- SQLBotLogUtil .info (json .dumps (table_schema ))
85- SQLBotLogUtil .info ('start table embedding' )
86- start_time = time .time ()
87- results = model .embed_documents (table_schema )
88- end_time = time .time ()
89- SQLBotLogUtil .info ('table embedding finished in:' + str (end_time - start_time ) + 'seconds' )
90-
91- for index in range (len (results )):
92- item = results [index ]
93- stmt = update (CoreTable ).where (and_ (CoreTable .id == _list [index ].id )).values (embedding = item )
83+ stmt = update (CoreTable ).where (and_ (CoreTable .id == id )).values (embedding = emb )
9484 session .execute (stmt )
9585 session .commit ()
9686
87+ end_time = time .time ()
88+ SQLBotLogUtil .info ('table embedding finished in:' + str (end_time - start_time ) + 'seconds' )
9789 except Exception :
9890 traceback .print_exc ()
0 commit comments