11import datetime
2+ import logging
3+ import traceback
4+ from concurrent .futures import ThreadPoolExecutor
25from typing import List , Optional
6+ from xml .dom .minidom import parseString
37
4- from sqlalchemy import and_ , or_ , select , func , delete , update
8+ import dicttoxml
9+ from sqlalchemy import and_ , or_ , select , func , delete , update , union
10+ from sqlalchemy import create_engine , text
511from sqlalchemy .orm import aliased
12+ from sqlalchemy .orm import sessionmaker
613
14+ from apps .ai_model .embedding import EmbeddingModelCache
15+ from apps .template .generate_chart .generator import get_base_terminology_template
716from apps .terminology .models .terminology_model import Terminology , TerminologyInfo
17+ from common .core .config import settings
818from common .core .deps import SessionDep
919
# Shared thread pool used by run_save_embeddings / fill_empty_embeddings to run
# embedding work in the background instead of blocking the caller.
# NOTE(review): 200 workers is a high ceiling for tasks that each open their own
# database engine — confirm this matches the database's connection limits.
executor = ThreadPoolExecutor(max_workers=200)
21+
1022
1123def page_terminology (session : SessionDep , current_page : int = 1 , page_size : int = 10 , name : Optional [str ] = None ):
1224 _list : List [TerminologyInfo ] = []
@@ -24,7 +36,7 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
2436 # 步骤1:先找到所有匹配的节点ID(无论是父节点还是子节点)
2537 matched_ids_subquery = (
2638 select (Terminology .id )
27- .where (Terminology .word .like (keyword_pattern )) # LIKE查询条件
39+ .where (Terminology .word .ilike (keyword_pattern )) # LIKE查询条件
2840 .subquery ()
2941 )
3042
@@ -82,7 +94,6 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
8294 .where (Terminology .id .in_ (paginated_parent_ids ))
8395 .order_by (Terminology .create_time .desc ())
8496 )
85- print (str (stmt ))
8697 else :
8798 parent_ids_subquery = (
8899 select (Terminology .id )
@@ -113,7 +124,6 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
113124 .group_by (Terminology .id , Terminology .word )
114125 .order_by (Terminology .create_time .desc ())
115126 )
116- print (str (stmt ))
117127
118128 result = session .execute (stmt )
119129
@@ -145,13 +155,16 @@ def create_terminology(session: SessionDep, info: TerminologyInfo):
145155 _list : List [Terminology ] = []
146156 if info .other_words :
147157 for other_word in info .other_words :
158+ if other_word .strip () == "" :
159+ continue
148160 _list .append (
149161 Terminology (pid = result .id , word = other_word , create_time = create_time ))
150162 session .bulk_save_objects (_list )
151163 session .flush ()
152164 session .commit ()
153165
154- # todo embedding
166+ # embedding
167+ run_save_embeddings ([result .id ])
155168
156169 return result .id
157170
@@ -172,13 +185,16 @@ def update_terminology(session: SessionDep, info: TerminologyInfo):
172185 _list : List [Terminology ] = []
173186 if info .other_words :
174187 for other_word in info .other_words :
188+ if other_word .strip () == "" :
189+ continue
175190 _list .append (
176191 Terminology (pid = info .id , word = other_word , create_time = create_time ))
177192 session .bulk_save_objects (_list )
178193 session .flush ()
179194 session .commit ()
180195
181- # todo embedding
196+ # embedding
197+ run_save_embeddings ([info .id ])
182198
183199 return info .id
184200
@@ -187,3 +203,172 @@ def delete_terminology(session: SessionDep, ids: list[int]):
187203 stmt = delete (Terminology ).where (or_ (Terminology .id .in_ (ids ), Terminology .pid .in_ (ids )))
188204 session .execute (stmt )
189205 session .commit ()
206+
207+
def run_save_embeddings(ids: List[int]) -> None:
    """Schedule ``save_embeddings(ids)`` on the module-level thread pool.

    Returns immediately; the future produced by ``executor.submit`` is
    discarded, so the caller never sees the outcome of the background task.

    Args:
        ids: Terminology ids handed unchanged to ``save_embeddings``.
    """
    executor.submit(save_embeddings, ids)
210+
211+
def fill_empty_embeddings() -> None:
    """Schedule ``run_fill_empty_embeddings`` on the module-level thread pool.

    Returns immediately; the submitted future is not kept.
    """
    executor.submit(run_fill_empty_embeddings)
214+
215+
def run_fill_empty_embeddings():
    """Generate embeddings for every terminology group that is missing one.

    Does nothing when ``settings.EMBEDDING_ENABLED`` is falsy. Otherwise it
    collects:

    * ids of root terms (``pid`` is NULL) whose ``embedding`` is NULL, and
    * distinct parent ids of child terms whose ``embedding`` is NULL,

    and passes the combined id list to ``save_embeddings``.

    The lookup uses its own engine and session (this function is run from the
    thread pool, outside any request-scoped session). Both are released before
    the embedding work starts so no connection pool is left behind per run.
    """
    if not settings.EMBEDDING_ENABLED:
        return

    engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
    try:
        # The context manager closes the session even if the query raises.
        with sessionmaker(bind=engine)() as session:
            roots_missing = select(Terminology.id).where(
                and_(Terminology.embedding.is_(None), Terminology.pid.is_(None))
            )
            parents_of_missing = (
                select(Terminology.pid)
                .where(and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None)))
                .distinct()
            )
            # UNION also de-duplicates ids that appear in both branches.
            pending_ids = session.execute(
                union(roots_missing, parents_of_missing)
            ).scalars().all()
    finally:
        # Release the pooled connections owned by this throw-away engine.
        engine.dispose()

    save_embeddings(pending_ids)
227+
228+
def save_embeddings(ids: List[int]):
    """Compute and store embeddings for the given terms and their children.

    Loads every ``Terminology`` row whose ``id`` or ``pid`` is in ``ids``,
    embeds each row's ``word`` with the cached embedding model, and writes the
    vector to that row's ``embedding`` column in a single commit.

    This is a best-effort background task: any exception is printed via
    ``traceback.print_exc()`` and swallowed rather than propagated.

    Args:
        ids: Root terminology ids. ``None`` or an empty list is a no-op, as is
            a falsy ``settings.EMBEDDING_ENABLED``.
    """
    if not settings.EMBEDDING_ENABLED:
        return

    if not ids:
        return

    try:
        engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
        try:
            # The context manager closes the session on success and on error.
            with sessionmaker(bind=engine)() as session:
                rows = (
                    session.query(Terminology)
                    .filter(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids)))
                    .all()
                )
                if not rows:
                    # Nothing matched; skip the model call entirely.
                    return

                model = EmbeddingModelCache.get_model()
                vectors = model.embed_documents([row.word for row in rows])

                # Vectors come back in the same order as the words sent in.
                for row, vector in zip(rows, vectors):
                    session.execute(
                        update(Terminology)
                        .where(Terminology.id == row.id)
                        .values(embedding=vector)
                    )
                session.commit()
        finally:
            # Release the pooled connections owned by this throw-away engine.
            engine.dispose()

    except Exception:
        traceback.print_exc()
256+
257+
# Raw SQL for the vector-similarity lookup executed by select_terminology_by_word.
# - `:embedding_array` is a bind parameter supplied at execution time.
# - Child rows without their own description fall back to the parent's description.
# - The similarity threshold and row limit are interpolated from `settings` once,
#   when this module is imported (this is a module-level f-string).
# NOTE(review): `<=>` is presumably pgvector's cosine-distance operator, making
# `1 - distance` a cosine similarity — confirm against the column type.
embedding_sql = f"""
SELECT id, pid, word, description, similarity
FROM
(SELECT id, pid, word,
COALESCE(
description,
(SELECT description FROM terminology AS parent WHERE parent.id = child.pid)
) AS description,
( 1 - (embedding <=> :embedding_array) ) AS similarity
FROM terminology AS child
) TEMP
WHERE similarity > {settings.EMBEDDING_SIMILARITY}
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_TOP_COUNT}
"""
273+
274+
def select_terminology_by_word(session: SessionDep, word: str):
    """Find terminology entries relevant to a piece of text.

    Two lookups are combined:

    1. A substring match: every term whose ``word`` occurs (case-insensitively)
       inside ``word``.
    2. When ``settings.EMBEDDING_ENABLED`` is truthy, the vector-similarity
       query in ``embedding_sql``. Failures here are printed with
       ``traceback.print_exc()`` and ignored, so the substring matches are
       still returned.

    Rows are de-duplicated by id and grouped under their root term (``pid`` if
    set, otherwise the row's own ``id``).

    Args:
        session: Database session used for both queries.
        word: The text to search terms in; ``None``, empty or whitespace-only
            input yields an empty list.

    Returns:
        A list of ``{'words': [...], 'description': ...}`` dicts, one per root
        term, in first-seen order. The description is taken from the first row
        seen for that group.
    """
    if not word or not word.strip():
        return []

    # A child row has no description of its own, so fall back to its parent's.
    # The parent must be a separate alias: comparing Terminology.id with
    # Terminology.pid on the same entity never reaches the parent row.
    parent = aliased(Terminology)
    parent_description = (
        select(parent.description)
        .where(parent.id == Terminology.pid)
        .scalar_subquery()
    )

    stmt = (
        select(
            Terminology.id,
            Terminology.pid,
            Terminology.word,
            func.coalesce(Terminology.description, parent_description).label('description'),
        )
        .where(
            # Matches when the stored term occurs anywhere inside the input text.
            text(":sentence ILIKE '%' || word || '%'")
        )
    )

    matches = list(session.execute(stmt, {'sentence': word}).fetchall())

    if settings.EMBEDDING_ENABLED:
        try:
            model = EmbeddingModelCache.get_model()
            query_vector = model.embed_query(word)
            # The vector is passed in its string form as the bind value.
            matches.extend(
                session.execute(text(embedding_sql), {'embedding_array': str(query_vector)})
            )
        except Exception:
            # Best effort: keep the substring matches if the vector lookup fails.
            traceback.print_exc()

    grouped: dict = {}
    seen_ids: set = set()
    for row in matches:
        if row.id in seen_ids:
            continue
        seen_ids.add(row.id)
        group_key = str(row.pid) if row.pid else str(row.id)
        group = grouped.setdefault(group_key, {'words': [], 'description': row.description})
        group['words'].append(row.word)

    return list(grouped.values())
337+
338+
def get_example():
    """Return a sample terminology entry rendered as XML under an ``example`` root."""
    sample_term = {
        'words': ['GDP', '国内生产总值'],
        'description': '指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。',
    }
    return to_xml_string({'terminologies': [sample_term]}, 'example')
347+
348+
def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
    """Serialize terminology data to an indented XML string.

    List items are tagged ``terminology`` under ``terminologies``, ``word``
    under ``words`` and ``item`` under anything else. The leading XML
    declaration added by the pretty-printer is removed from the result.

    Args:
        _dict: The data to serialize.
        root: Name of the root element.

    Returns:
        The pretty-printed XML without an ``<?xml ...?>`` header.
    """

    def _item_tag(parent_name):
        # Pick the element name for list items based on the parent's name.
        if parent_name == 'terminologies':
            return 'terminology'
        if parent_name == 'words':
            return 'word'
        return 'item'

    # Only let dicttoxml log at ERROR level and above.
    dicttoxml.LOG.setLevel(logging.ERROR)

    raw_bytes = dicttoxml.dicttoxml(
        _dict,
        custom_root=root,
        item_func=_item_tag,
        xml_declaration=False,
        encoding='utf-8',
        attr_type=False,
    )
    pretty_xml = parseString(raw_bytes.decode('utf-8')).toprettyxml()

    if not pretty_xml.startswith('<?xml'):
        return pretty_xml

    # Drop everything up to and including the end of the declaration.
    declaration_end = pretty_xml.find('>') + 1
    return pretty_xml[declaration_end:].lstrip()
365+
366+
def get_terminology_template(session: SessionDep, question: str) -> str:
    """Build the terminology prompt section for a question.

    Looks up the terms relevant to ``question``, renders them as XML and
    substitutes that XML for the ``terminologies`` placeholder of the base
    terminology template.

    Returns:
        The filled-in template, or an empty string when no term matched.
    """
    matched_terms = select_terminology_by_word(session, question)
    if not matched_terms:
        return ''

    terminology_xml = to_xml_string(matched_terms)
    return get_base_terminology_template().format(terminologies=terminology_xml)
0 commit comments