1818from common .utils .embedding_threads import run_save_data_training_embeddings
1919
2020
21- def page_data_training (session : SessionDep , current_page : int = 1 , page_size : int = 10 , name : Optional [str ] = None ,
22- oid : Optional [int ] = 1 ):
23- _list : List [DataTrainingInfoResult ] = []
24-
25- current_page = max (1 , current_page )
26- page_size = max (10 , page_size )
27-
28- total_count = 0
29- total_pages = 0
30-
21+ def get_data_training_base_query (oid : int , name : Optional [str ] = None ):
22+ """
23+ 获取数据训练查询的基础查询结构
24+ """
3125 if name and name .strip () != "" :
3226 keyword_pattern = f"%{ name .strip ()} %"
3327 parent_ids_subquery = (
3428 select (DataTraining .id )
35- .where (and_ (DataTraining .question .ilike (keyword_pattern ), DataTraining .oid == oid )) # LIKE查询条件
29+ .where (and_ (DataTraining .question .ilike (keyword_pattern ), DataTraining .oid == oid ))
3630 )
3731 else :
3832 parent_ids_subquery = (
3933 select (DataTraining .id ).where (and_ (DataTraining .oid == oid ))
4034 )
4135
36+ return parent_ids_subquery
37+
38+
39+ def build_data_training_query (session : SessionDep , oid : int , name : Optional [str ] = None ,
40+ paginate : bool = True , current_page : int = 1 , page_size : int = 10 ):
41+ """
42+ 构建数据训练查询的通用方法
43+ """
44+ parent_ids_subquery = get_data_training_base_query (oid , name )
45+
46+ # 计算总数
4247 count_stmt = select (func .count ()).select_from (parent_ids_subquery .subquery ())
4348 total_count = session .execute (count_stmt ).scalar ()
44- total_pages = (total_count + page_size - 1 ) // page_size
4549
46- if current_page > total_pages :
50+ if paginate :
51+ # 分页处理
52+ page_size = max (10 , page_size )
53+ total_pages = (total_count + page_size - 1 ) // page_size
54+ current_page = max (1 , min (current_page , total_pages )) if total_pages > 0 else 1
55+
56+ paginated_parent_ids = (
57+ parent_ids_subquery
58+ .order_by (DataTraining .create_time .desc ())
59+ .offset ((current_page - 1 ) * page_size )
60+ .limit (page_size )
61+ .subquery ()
62+ )
63+ else :
64+ # 不分页,获取所有数据
65+ total_pages = 1
4766 current_page = 1
67+ page_size = total_count if total_count > 0 else 1
4868
49- paginated_parent_ids = (
50- parent_ids_subquery
51- .order_by (DataTraining .create_time .desc ())
52- .offset ((current_page - 1 ) * page_size )
53- .limit (page_size )
54- .subquery ()
55- )
69+ paginated_parent_ids = (
70+ parent_ids_subquery
71+ .order_by (DataTraining .create_time .desc ())
72+ .subquery ()
73+ )
5674
75+ # 构建主查询
5776 stmt = (
5877 select (
5978 DataTraining .id ,
@@ -74,6 +93,14 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in
7493 .order_by (DataTraining .create_time .desc ())
7594 )
7695
96+ return stmt , total_count , total_pages , current_page , page_size
97+
98+
99+ def execute_data_training_query (session : SessionDep , stmt ) -> List [DataTrainingInfoResult ]:
100+ """
101+ 执行查询并返回数据训练信息列表
102+ """
103+ _list = []
77104 result = session .execute (stmt )
78105
79106 for row in result :
@@ -90,9 +117,34 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in
90117 advanced_application_name = row .advanced_application_name ,
91118 ))
92119
120+ return _list
121+
122+
123+ def page_data_training (session : SessionDep , current_page : int = 1 , page_size : int = 10 ,
124+ name : Optional [str ] = None , oid : Optional [int ] = 1 ):
125+ """
126+ 分页查询数据训练(原方法保持不变)
127+ """
128+ stmt , total_count , total_pages , current_page , page_size = build_data_training_query (
129+ session , oid , name , True , current_page , page_size
130+ )
131+ _list = execute_data_training_query (session , stmt )
132+
93133 return current_page , page_size , total_count , total_pages , _list
94134
95135
136+ def get_all_data_training (session : SessionDep , name : Optional [str ] = None , oid : Optional [int ] = 1 ):
137+ """
138+ 获取所有数据训练(不分页)
139+ """
140+ stmt , total_count , total_pages , current_page , page_size = build_data_training_query (
141+ session , oid , name , False
142+ )
143+ _list = execute_data_training_query (session , stmt )
144+
145+ return _list
146+
147+
96148def create_training (session : SessionDep , info : DataTrainingInfo , oid : int , trans : Trans ):
97149 create_time = datetime .datetime .now ()
98150 if info .datasource is None and info .advanced_application is None :
0 commit comments