99from sqlalchemy import text
1010
1111from apps .ai_model .embedding import EmbeddingModelCache
12- from apps .data_training .models .data_training_model import DataTrainingInfo , DataTraining
12+ from apps .data_training .models .data_training_model import DataTrainingInfo , DataTraining , DataTrainingInfoResult
1313from apps .datasource .models .datasource import CoreDatasource
14+ from apps .system .models .system_model import AssistantModel
1415from apps .template .generate_chart .generator import get_base_data_training_template
1516from common .core .config import settings
1617from common .core .deps import SessionDep , Trans
1920
2021def page_data_training (session : SessionDep , current_page : int = 1 , page_size : int = 10 , name : Optional [str ] = None ,
2122 oid : Optional [int ] = 1 ):
22- _list : List [DataTrainingInfo ] = []
23+ _list : List [DataTrainingInfoResult ] = []
2324
2425 current_page = max (1 , current_page )
2526 page_size = max (10 , page_size )
@@ -63,40 +64,60 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in
6364 DataTraining .create_time ,
6465 DataTraining .description ,
6566 DataTraining .enabled ,
67+ DataTraining .advanced_application ,
68+ AssistantModel .name .label ('advanced_application_name' ),
6669 )
6770 .outerjoin (CoreDatasource , and_ (DataTraining .datasource == CoreDatasource .id ))
71+ .outerjoin (AssistantModel ,
72+ and_ (DataTraining .advanced_application == AssistantModel .id , AssistantModel .type == 1 ))
6873 .where (and_ (DataTraining .id .in_ (paginated_parent_ids )))
6974 .order_by (DataTraining .create_time .desc ())
7075 )
7176
7277 result = session .execute (stmt )
7378
7479 for row in result :
75- _list .append (DataTrainingInfo (
76- id = row .id ,
77- oid = row .oid ,
80+ _list .append (DataTrainingInfoResult (
81+ id = str ( row .id ) ,
82+ oid = str ( row .oid ) ,
7883 datasource = row .datasource ,
7984 datasource_name = row .name ,
8085 question = row .question ,
8186 create_time = row .create_time ,
8287 description = row .description ,
8388 enabled = row .enabled ,
89+ advanced_application = str (row .advanced_application ) if row .advanced_application else None ,
90+ advanced_application_name = row .advanced_application_name ,
8491 ))
8592
8693 return current_page , page_size , total_count , total_pages , _list
8794
8895
8996def create_training (session : SessionDep , info : DataTrainingInfo , oid : int , trans : Trans ):
9097 create_time = datetime .datetime .now ()
91- if info .datasource is None :
92- raise Exception (trans ("i18n_data_training.datasource_cannot_be_none" ))
98+ if info .datasource is None and info .advanced_application is None :
99+ if oid == 1 :
100+ raise Exception (trans ("i18n_data_training.datasource_assistant_cannot_be_none" ))
101+ else :
102+ raise Exception (trans ("i18n_data_training.datasource_cannot_be_none" ))
103+
93104 parent = DataTraining (question = info .question , create_time = create_time , description = info .description , oid = oid ,
94- datasource = info .datasource , enabled = info .enabled )
105+ datasource = info .datasource , enabled = info .enabled ,
106+ advanced_application = info .advanced_application )
107+
108+ stmt = select (DataTraining .id ).where (and_ (DataTraining .question == info .question , DataTraining .oid == oid ))
109+
110+ if info .datasource is not None and info .advanced_application is not None :
111+ stmt = stmt .where (
112+ or_ (DataTraining .datasource == info .datasource ,
113+ DataTraining .advanced_application == info .advanced_application ))
114+ elif info .datasource is not None and info .advanced_application is None :
115+ stmt = stmt .where (and_ (DataTraining .datasource == info .datasource ))
116+ elif info .datasource is None and info .advanced_application is not None :
117+ stmt = stmt .where (and_ (DataTraining .advanced_application == info .advanced_application ))
118+
119+ exists = session .query (stmt .exists ()).scalar ()
95120
96- exists = session .query (
97- session .query (DataTraining ).filter (
98- and_ (DataTraining .question == info .question , DataTraining .oid == oid ,
99- DataTraining .datasource == info .datasource )).exists ()).scalar ()
100121 if exists :
101122 raise Exception (trans ("i18n_data_training.exists_in_db" ))
102123
@@ -116,20 +137,32 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
116137
117138
118139def update_training (session : SessionDep , info : DataTrainingInfo , oid : int , trans : Trans ):
119- if info .datasource is None :
120- raise Exception (trans ("i18n_data_training.datasource_cannot_be_none" ))
140+ if info .datasource is None and info .advanced_application is None :
141+ if oid == 1 :
142+ raise Exception (trans ("i18n_data_training.datasource_assistant_cannot_be_none" ))
143+ else :
144+ raise Exception (trans ("i18n_data_training.datasource_cannot_be_none" ))
121145
122146 count = session .query (DataTraining ).filter (
123147 DataTraining .id == info .id
124148 ).count ()
125149 if count == 0 :
126150 raise Exception (trans ('i18n_data_training.data_training_not_exists' ))
127151
128- exists = session .query (
129- session .query (DataTraining ).filter (
130- and_ (DataTraining .question == info .question , DataTraining .oid == oid ,
131- DataTraining .datasource == info .datasource ,
132- DataTraining .id != info .id )).exists ()).scalar ()
152+ stmt = select (DataTraining .id ).where (
153+ and_ (DataTraining .question == info .question , DataTraining .oid == oid , DataTraining .id != info .id ))
154+
155+ if info .datasource is not None and info .advanced_application is not None :
156+ stmt = stmt .where (
157+ or_ (DataTraining .datasource == info .datasource ,
158+ DataTraining .advanced_application == info .advanced_application ))
159+ elif info .datasource is not None and info .advanced_application is None :
160+ stmt = stmt .where (and_ (DataTraining .datasource == info .datasource ))
161+ elif info .datasource is None and info .advanced_application is not None :
162+ stmt = stmt .where (and_ (DataTraining .advanced_application == info .advanced_application ))
163+
164+ exists = session .query (stmt .exists ()).scalar ()
165+
133166 if exists :
134167 raise Exception (trans ("i18n_data_training.exists_in_db" ))
135168
@@ -138,6 +171,7 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
138171 description = info .description ,
139172 datasource = info .datasource ,
140173 enabled = info .enabled ,
174+ advanced_application = info .advanced_application ,
141175 )
142176 session .execute (stmt )
143177 session .commit ()
@@ -231,9 +265,21 @@ def save_embeddings(session_maker, ids: List[int]):
231265ORDER BY similarity DESC
232266LIMIT { settings .EMBEDDING_DATA_TRAINING_TOP_COUNT }
233267"""
# Similarity lookup over `data_training` rows scoped to one advanced application
# rather than to a datasource.
#
# Bind parameters expected at execution time:
#   :embedding_array      - the query embedding, passed as its string form
#   :oid                  - organisation/workspace id the rows must belong to
#   :advanced_application - id of the advanced application to match
#
# The similarity threshold and the row limit are read from `settings` and
# interpolated into the SQL text when this assignment executes; only the values
# above are bound per query.
#
# NOTE(review): `<=>` is presumably pgvector's cosine-distance operator, making
# `1 - distance` a cosine similarity - confirm against the column type.
#
# `advanced_application` must be part of the inner SELECT list: the outer WHERE
# clause can only reference columns that the derived table TEMP projects, and
# filtering on a column that is not projected makes the statement fail.
embedding_sql_in_advanced_application = f"""
SELECT id, datasource, question, similarity
FROM
(SELECT id, datasource, question, oid, enabled, advanced_application,
( 1 - (embedding <=> :embedding_array) ) AS similarity
FROM data_training AS child
) TEMP
WHERE similarity > {settings.EMBEDDING_DATA_TRAINING_SIMILARITY} and oid = :oid and advanced_application = :advanced_application and enabled = true
ORDER BY similarity DESC
LIMIT {settings.EMBEDDING_DATA_TRAINING_TOP_COUNT}
"""
234279
235280
236- def select_training_by_question (session : SessionDep , question : str , oid : int , datasource : int ):
281+ def select_training_by_question (session : SessionDep , question : str , oid : int , datasource : Optional [int ] = None ,
282+ advanced_application_id : Optional [int ] = None ):
237283 if question .strip () == "" :
238284 return []
239285
@@ -248,10 +294,13 @@ def select_training_by_question(session: SessionDep, question: str, oid: int, da
248294 .where (
249295 and_ (or_ (text (":sentence ILIKE '%' || question || '%'" ), text ("question ILIKE '%' || :sentence || '%'" )),
250296 DataTraining .oid == oid ,
251- DataTraining .datasource == datasource ,
252- DataTraining .enabled == True ,)
297+ DataTraining .enabled == True )
253298 )
254299 )
300+ if advanced_application_id is not None :
301+ stmt = stmt .where (and_ (DataTraining .advanced_application == advanced_application_id ))
302+ else :
303+ stmt = stmt .where (and_ (DataTraining .datasource == datasource ))
255304
256305 results = session .execute (stmt , {'sentence' : question }).fetchall ()
257306
@@ -264,8 +313,13 @@ def select_training_by_question(session: SessionDep, question: str, oid: int, da
264313
265314 embedding = model .embed_query (question )
266315
267- results = session .execute (text (embedding_sql ),
268- {'embedding_array' : str (embedding ), 'oid' : oid , 'datasource' : datasource })
316+ if advanced_application_id is not None :
317+ results = session .execute (text (embedding_sql_in_advanced_application ),
318+ {'embedding_array' : str (embedding ), 'oid' : oid ,
319+ 'advanced_application' : advanced_application_id })
320+ else :
321+ results = session .execute (text (embedding_sql ),
322+ {'embedding_array' : str (embedding ), 'oid' : oid , 'datasource' : datasource })
269323
270324 for row in results :
271325 _list .append (DataTraining (id = row .id , question = row .question ))
@@ -328,12 +382,13 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'sql-examples') -> str:
328382 return pretty_xml
329383
330384
331- def get_training_template (session : SessionDep , question : str , datasource : int , oid : Optional [int ] = 1 ) -> str :
385+ def get_training_template (session : SessionDep , question : str , oid : Optional [int ] = 1 , datasource : Optional [int ] = None ,
386+ advanced_application_id : Optional [int ] = None ) -> str :
332387 if not oid :
333388 oid = 1
334- if not datasource :
389+ if not datasource and not advanced_application_id :
335390 return ''
336- _results = select_training_by_question (session , question , oid , datasource )
391+ _results = select_training_by_question (session , question , oid , datasource , advanced_application_id )
337392 if _results and len (_results ) > 0 :
338393 data_training = to_xml_string (_results )
339394 template = get_base_data_training_template ().format (data_training = data_training )
0 commit comments