1515from apps .template .generate_chart .generator import get_base_terminology_template
1616from apps .terminology .models .terminology_model import Terminology , TerminologyInfo
1717from common .core .config import settings
18- from common .core .deps import SessionDep
18+ from common .core .deps import SessionDep , Trans
1919
2020executor = ThreadPoolExecutor (max_workers = 200 )
2121
2222
23- def page_terminology (session : SessionDep , current_page : int = 1 , page_size : int = 10 , name : Optional [str ] = None ):
23+ def page_terminology (session : SessionDep , current_page : int = 1 , page_size : int = 10 , name : Optional [str ] = None ,
24+ oid : Optional [int ] = 1 ):
2425 _list : List [TerminologyInfo ] = []
2526
2627 child = aliased (Terminology )
@@ -91,7 +92,7 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
9192 children_subquery ,
9293 Terminology .id == children_subquery .c .pid
9394 )
94- .where (Terminology .id .in_ (paginated_parent_ids ))
95+ .where (and_ ( Terminology .id .in_ (paginated_parent_ids ), Terminology . oid == oid ))
9596 .order_by (Terminology .create_time .desc ())
9697 )
9798 else :
@@ -120,7 +121,7 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
120121 func .jsonb_agg (child .word ).filter (child .word .isnot (None )).label ('other_words' )
121122 )
122123 .outerjoin (child , and_ (Terminology .id == child .pid ))
123- .where (Terminology .id .in_ (paginated_parent_ids ))
124+ .where (and_ ( Terminology .id .in_ (paginated_parent_ids ), Terminology . oid == oid ))
124125 .group_by (Terminology .id , Terminology .word )
125126 .order_by (Terminology .create_time .desc ())
126127 )
@@ -139,9 +140,9 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
139140 return current_page , page_size , total_count , total_pages , _list
140141
141142
142- def create_terminology (session : SessionDep , info : TerminologyInfo ):
143+ def create_terminology (session : SessionDep , info : TerminologyInfo , oid : int ):
143144 create_time = datetime .datetime .now ()
144- parent = Terminology (word = info .word , create_time = create_time , description = info .description )
145+ parent = Terminology (word = info .word , create_time = create_time , description = info .description , oid = oid )
145146
146147 result = Terminology (** parent .model_dump ())
147148
@@ -158,7 +159,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo):
158159 if other_word .strip () == "" :
159160 continue
160161 _list .append (
161- Terminology (pid = result .id , word = other_word , create_time = create_time ))
162+ Terminology (pid = result .id , word = other_word , create_time = create_time , oid = oid ))
162163 session .bulk_save_objects (_list )
163164 session .flush ()
164165 session .commit ()
@@ -169,7 +170,14 @@ def create_terminology(session: SessionDep, info: TerminologyInfo):
169170 return result .id
170171
171172
172- def update_terminology (session : SessionDep , info : TerminologyInfo ):
173+ def update_terminology (session : SessionDep , info : TerminologyInfo , oid : int , trans : Trans ):
174+ count = session .query (Terminology ).filter (
175+ Terminology .oid == oid ,
176+ Terminology .id == info .id
177+ ).count ()
178+ if count == 0 :
179+ raise Exception (trans ('i18n_terminology.terminology_not_exists' ))
180+
173181 stmt = update (Terminology ).where (and_ (Terminology .id == info .id )).values (
174182 word = info .word ,
175183 description = info .description ,
@@ -188,14 +196,15 @@ def update_terminology(session: SessionDep, info: TerminologyInfo):
188196 if other_word .strip () == "" :
189197 continue
190198 _list .append (
191- Terminology (pid = info .id , word = other_word , create_time = create_time ))
199+ Terminology (pid = info .id , word = other_word , create_time = create_time , oid = oid ))
192200 session .bulk_save_objects (_list )
193201 session .flush ()
194202 session .commit ()
195203
196204 # embedding
197205 run_save_embeddings ([info .id ])
198206
207+
199208 return info .id
200209
201210
@@ -256,23 +265,19 @@ def save_embeddings(ids: List[int]):
256265
257266
258267embedding_sql = f"""
259- SELECT id, pid, word, description, similarity
268+ SELECT id, pid, word, similarity
260269FROM
261- (SELECT id, pid, word,
262- COALESCE(
263- description,
264- (SELECT description FROM terminology AS parent WHERE parent.id = child.pid)
265- ) AS description,
270+ (SELECT id, pid, word, oid,
266271( 1 - (embedding <=> :embedding_array) ) AS similarity
267272FROM terminology AS child
268273) TEMP
269- WHERE similarity > { settings .EMBEDDING_SIMILARITY }
274+ WHERE similarity > { settings .EMBEDDING_SIMILARITY } and oid = :oid
270275ORDER BY similarity DESC
271276LIMIT { settings .EMBEDDING_TOP_COUNT }
272277"""
273278
274279
275- def select_terminology_by_word (session : SessionDep , word : str ):
280+ def select_terminology_by_word (session : SessionDep , word : str , oid : int ):
276281 if word .strip () == "" :
277282 return []
278283
@@ -283,48 +288,48 @@ def select_terminology_by_word(session: SessionDep, word: str):
283288 Terminology .id ,
284289 Terminology .pid ,
285290 Terminology .word ,
286- func .coalesce (
287- Terminology .description ,
288- select (Terminology .description )
289- .where (and_ (Terminology .id == Terminology .pid ))
290- .scalar_subquery ()
291- ).label ('description' )
292291 )
293292 .where (
294- text (":sentence ILIKE '%' || word || '%'" )
293+ and_ ( text (":sentence ILIKE '%' || word || '%'" ), Terminology . oid == oid )
295294 )
296295 )
297296
298297 results = session .execute (stmt , {'sentence' : word }).fetchall ()
299298
300299 for row in results :
301- _list .append (Terminology (id = row .id , word = row .word , pid = row .pid , description = row . description ))
300+ _list .append (Terminology (id = row .id , word = row .word , pid = row .pid ))
302301
303302 if settings .EMBEDDING_ENABLED :
304303 try :
305304 model = EmbeddingModelCache .get_model ()
306305
307306 embedding = model .embed_query (word )
308307
309- print (embedding_sql )
310- results = session .execute (text (embedding_sql ), {'embedding_array' : str (embedding )})
308+ results = session .execute (text (embedding_sql ), {'embedding_array' : str (embedding ), 'oid' : oid })
311309
312310 for row in results :
313- _list .append (Terminology (id = row .id , word = row .word , pid = row .pid , description = row . description ))
311+ _list .append (Terminology (id = row .id , word = row .word , pid = row .pid ))
314312
315313 except Exception :
316314 traceback .print_exc ()
317315
318316 _map : dict = {}
319- _ids : set [int ] = set ()
317+ _ids : list [int ] = []
320318 for row in _list :
321- if row .id in _ids :
319+ if row .id in _ids or row . pid in _ids :
322320 continue
323- _ids .add (row .id )
324- if row .pid :
325- pid = str (row .pid )
321+ if row .pid is not None :
322+ _ids .append (row .pid )
326323 else :
327- pid = str (row .id )
324+ _ids .append (row .id )
325+
326+ if len (_ids ) == 0 :
327+ return []
328+
329+ t_list = session .query (Terminology .id , Terminology .pid , Terminology .word , Terminology .description ).filter (
330+ or_ (Terminology .id .in_ (_ids ), Terminology .pid .in_ (_ids ))).all ()
331+ for row in t_list :
332+ pid = str (row .pid ) if row .pid is not None else str (row .id )
328333 if _map .get (pid ) is None :
329334 _map [pid ] = {'words' : [], 'description' : row .description }
330335 _map [pid ]['words' ].append (row .word )
@@ -350,6 +355,7 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
350355 item_name_func = lambda x : 'terminology' if x == 'terminologies' else 'word' if x == 'words' else 'item'
351356 dicttoxml .LOG .setLevel (logging .ERROR )
352357 xml = dicttoxml .dicttoxml (_dict ,
358+ cdata = ['word' , 'description' ],
353359 custom_root = root ,
354360 item_func = item_name_func ,
355361 xml_declaration = False ,
@@ -361,11 +367,22 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
361367 end_index = pretty_xml .find ('>' ) + 1
362368 pretty_xml = pretty_xml [end_index :].lstrip ()
363369
370+ # 替换所有 XML 转义字符
371+ escape_map = {
372+ '&lt;' : '<' ,
373+ '&gt;' : '>' ,
374+ '&amp;' : '&' ,
375+ '&quot;' : '"' ,
376+ '&apos;' : "'"
377+ }
378+ for escaped , original in escape_map .items ():
379+ pretty_xml = pretty_xml .replace (escaped , original )
380+
364381 return pretty_xml
365382
366383
367- def get_terminology_template (session : SessionDep , question : str ) -> str :
368- _results = select_terminology_by_word (session , question )
384+ def get_terminology_template (session : SessionDep , question : str , oid : int ) -> str :
385+ _results = select_terminology_by_word (session , question , oid )
369386 if _results and len (_results ) > 0 :
370387 terminology = to_xml_string (_results )
371388 template = get_base_terminology_template ().format (terminologies = terminology )
0 commit comments