99import datetime
1010import logging
1111import os
12+ import threading
1213import traceback
1314from typing import List
1415
1516import django .db .models
17+ from django .db import models
1618from django .db .models import QuerySet
19+ from django .db .models .functions import Substr , Reverse
1720from langchain_core .embeddings import Embeddings
1821
1922from common .config .embedding_config import VectorStore
20- from common .db .search import native_search , get_dynamics_model
21- from common .event . common import embedding_poxy
23+ from common .db .search import native_search , get_dynamics_model , native_update
24+ from common .db . sql_execute import sql_execute , update_execute
2225from common .util .file_util import get_file_content
2326from common .util .lock import try_lock , un_lock
24- from dataset .models import Paragraph , Status , Document , ProblemParagraphMapping
27+ from common .util .page_utils import page
28+ from dataset .models import Paragraph , Status , Document , ProblemParagraphMapping , TaskType , State
2529from embedding .models import SourceType , SearchMode
2630from smartdoc .conf import PROJECT_DIR
2731
2832max_kb_error = logging .getLogger (__file__ )
2933max_kb = logging .getLogger (__file__ )
34+ lock = threading .Lock ()
3035
3136
3237class SyncWebDatasetArgs :
@@ -114,7 +119,8 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
114119 @param embedding_model: 向量模型
115120 """
116121 max_kb .info (f"开始--->向量化段落:{ paragraph_id } " )
117- status = Status .success
122+ # 更新到开始状态
123+ ListenerManagement .update_status (QuerySet (Paragraph ).filter (id = paragraph_id ), TaskType .EMBEDDING , State .STARTED )
118124 try :
119125 data_list = native_search (
120126 {'problem' : QuerySet (get_dynamics_model ({'paragraph.id' : django .db .models .CharField ()})).filter (
@@ -125,23 +131,114 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
125131 # 删除段落
126132 VectorStore .get_embedding_vector ().delete_by_paragraph_id (paragraph_id )
127133
128- def is_save_function ():
129- return QuerySet (Paragraph ).filter (id = paragraph_id ).exists ()
134+ def is_the_task_interrupted ():
135+ _paragraph = QuerySet (Paragraph ).filter (id = paragraph_id ).first ()
136+ if _paragraph is None or Status (_paragraph .status )[TaskType .EMBEDDING ] == State .REVOKE :
137+ return True
138+ return False
130139
131140 # 批量向量化
132- VectorStore .get_embedding_vector ().batch_save (data_list , embedding_model , is_save_function )
141+ VectorStore .get_embedding_vector ().batch_save (data_list , embedding_model , is_the_task_interrupted )
142+ # Update the paragraph's embedding task to the SUCCESS state
143+ ListenerManagement .update_status (QuerySet (Paragraph ).filter (id = paragraph_id ), TaskType .EMBEDDING ,
144+ State .SUCCESS )
133145 except Exception as e :
134146 max_kb_error .error (f'向量化段落:{ paragraph_id } 出现错误{ str (e )} { traceback .format_exc ()} ' )
135- status = Status .error
147+ ListenerManagement .update_status (QuerySet (Paragraph ).filter (id = paragraph_id ), TaskType .EMBEDDING ,
148+ State .FAILURE )
136149 finally :
137- QuerySet (Paragraph ).filter (id = paragraph_id ).update (** {'status' : status })
138150 max_kb .info (f'结束--->向量化段落:{ paragraph_id } ' )
139151
@staticmethod
def embedding_by_data_list(data_list: List, embedding_model: Embeddings):
    """
    Vectorize an already-prepared list of records in a single batch call.

    @param data_list:       records handed to the vector store's batch_save
    @param embedding_model: embedding model used to vectorize the records
    """
    # The third argument of batch_save is an "is the task interrupted?" probe:
    # embedding_by_paragraph passes a callback that returns True when the
    # work must stop. This path has no revocable task, so the probe has to
    # answer False; the former `lambda: True` reported every batch as
    # interrupted.
    VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: False)
144156
@staticmethod
def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None):
    """
    Build the callback that vectorizes one page of paragraphs.

    @param embedding_model:         embedding model forwarded to embedding_by_paragraph
    @param is_the_task_interrupted: zero-argument probe; a truthy result stops the page early
    @param post_apply:              zero-argument hook invoked once the page has been handled
    @return: a function taking a list of paragraph rows (mappings exposing 'id')
    """

    def embedding_paragraph_apply(paragraph_list):
        for row in paragraph_list:
            # Re-check before every paragraph so a revoke takes effect quickly.
            if is_the_task_interrupted():
                break
            target_id = str(row.get('id'))
            ListenerManagement.embedding_by_paragraph(target_id, embedding_model)
        # Runs whether the page completed or was cut short.
        post_apply()

    return embedding_paragraph_apply
167+
@staticmethod
def get_aggregation_document_status(document_id):
    """
    Build a callback that refreshes the status aggregation of one document.

    @param document_id: id of the document to refresh
    @return: zero-argument function that runs update_document_status_meta.sql
             against the document selected by id
    """

    def aggregation_document_status():
        def select_document():
            # A fresh queryset per slot, as each SQL fragment receives its own.
            return QuerySet(Document).filter(id=document_id)

        sql_file = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')
        statement = get_file_content(sql_file)
        query_sets = {'document_custom_sql': select_document(),
                      'default_sql': select_document()}
        native_update(query_sets, statement, with_table_name=True)

    return aggregation_document_status
177+
@staticmethod
def get_aggregation_document_status_by_dataset_id(dataset_id):
    """
    Build a callback that refreshes the status aggregation of every document
    belonging to one dataset.

    @param dataset_id: id of the dataset whose documents are refreshed
    @return: zero-argument function that runs update_document_status_meta.sql
             against the documents filtered by dataset_id
    """

    def aggregation_document_status():
        def select_documents():
            # A fresh queryset per slot, as each SQL fragment receives its own.
            return QuerySet(Document).filter(dataset_id=dataset_id)

        sql_file = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')
        statement = get_file_content(sql_file)
        query_sets = {'document_custom_sql': select_documents(),
                      'default_sql': select_documents()}
        # NOTE(review): unlike get_aggregation_document_status, this call does
        # not pass with_table_name=True - confirm that is intended.
        native_update(query_sets, statement)

    return aggregation_document_status
187+
@staticmethod
def get_aggregation_document_status_by_query_set(queryset):
    """
    Build a callback that refreshes the status aggregation of the documents
    selected by a caller-supplied queryset.

    @param queryset: queryset used for both SQL fragments of the update
    @return: zero-argument function that runs update_document_status_meta.sql
    """

    def aggregation_document_status():
        sql_file = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')
        statement = get_file_content(sql_file)
        # The same queryset object feeds both slots.
        query_sets = dict.fromkeys(('document_custom_sql', 'default_sql'), queryset)
        native_update(query_sets, statement)

    return aggregation_document_status
196+
@staticmethod
def post_update_document_status(document_id, task_type: TaskType):
    """
    Settle the final state of one task type on a document after processing,
    then promote its paragraphs that are still in REVOKE to REVOKED.

    @param document_id: id of the document whose status is finalised
    @param task_type:   task whose slot in the status string is updated
    """
    _document = QuerySet(Document).filter(id=document_id).first()
    if _document is None:
        # The document was deleted while the task ran. This method is called
        # from a `finally` block, so raising here would skip the caller's
        # remaining cleanup; there is nothing left to finalise.
        return

    status = Status(_document.status)
    if status[task_type] == State.REVOKE:
        status[task_type] = State.REVOKED
    else:
        status[task_type] = State.SUCCESS
    # Any aggregated paragraph status reporting a failure for this task marks
    # the whole document as failed.
    # NOTE(review): assumed to run for both branches above - confirm nesting.
    # NOTE(review): `or {}` guards a possibly-null status_meta - confirm schema.
    for item in (_document.status_meta or {}).get('aggs', []):
        agg_status = item.get('status')
        agg_count = item.get('count')
        if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0:
            status[task_type] = State.FAILURE
    ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type])

    # The status string is reversed so that task_type.value is the 1-based
    # position of this task's state character. Exactly one character is
    # extracted: using task_type.value as the length (as before) returned
    # several characters for any task type above 1 and never matched REVOKE.
    revoking_paragraphs = QuerySet(Paragraph).annotate(
        reversed_status=Reverse('status'),
        task_type_status=Substr('reversed_status', task_type.value, 1),
    ).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id')
    ListenerManagement.update_status(revoking_paragraphs, task_type, State.REVOKED)
220+
@staticmethod
def update_status(query_set: QuerySet, taskType: TaskType, state: State):
    """
    Write `state` into the slot of `taskType` for every row selected by
    `query_set`, using the update_paragraph_status.sql template.

    @param query_set: rows to update; its model supplies the table name
    @param taskType:  task whose slot is rewritten (its value is used as an index)
    @param state:     new state; its value is substituted into the SQL
    """
    statement = get_file_content(
        os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_paragraph_status.sql'))
    position = taskType.value
    # Placeholders are substituted textually, in this order.
    substitutions = {
        '${bit_number}': len(TaskType),
        '${up_index}': position - 1,
        '${status_number}': state.value,
        '${next_index}': position + 1,
        '${table_name}': query_set.model._meta.db_table,
        '${current_index}': position,
    }
    for placeholder, replacement in substitutions.items():
        statement = statement.replace(placeholder, str(replacement))
    # Status writes are serialized through the module-level lock.
    with lock:
        native_update(query_set, statement)
241+
145242 @staticmethod
146243 def embedding_by_document (document_id , embedding_model : Embeddings ):
147244 """
@@ -153,33 +250,29 @@ def embedding_by_document(document_id, embedding_model: Embeddings):
153250 if not try_lock ('embedding' + str (document_id )):
154251 return
155252 max_kb .info (f"开始--->向量化文档:{ document_id } " )
156- QuerySet (Document ).filter (id = document_id ).update (** {'status' : Status .embedding })
157- QuerySet (Paragraph ).filter (document_id = document_id ).update (** {'status' : Status .embedding })
158- status = Status .success
253+ # Update the document's embedding task to the STARTED state
254+ ListenerManagement .update_status (QuerySet (Document ).filter (id = document_id ), TaskType .EMBEDDING , State .STARTED )
159255 try :
160- data_list = native_search (
161- {'problem' : QuerySet (
162- get_dynamics_model ({'paragraph.document_id' : django .db .models .CharField ()})).filter (
163- ** {'paragraph.document_id' : document_id }),
164- 'paragraph' : QuerySet (Paragraph ).filter (document_id = document_id )},
165- select_string = get_file_content (
166- os .path .join (PROJECT_DIR , "apps" , "common" , 'sql' , 'list_embedding_text.sql' )))
167256 # 删除文档向量数据
168257 VectorStore .get_embedding_vector ().delete_by_document_id (document_id )
169258
170- def is_save_function ():
171- return QuerySet (Document ).filter (id = document_id ).exists ()
172-
173- # 批量向量化
174- VectorStore .get_embedding_vector ().batch_save (data_list , embedding_model , is_save_function )
259+ def is_the_task_interrupted ():
260+ document = QuerySet (Document ).filter (id = document_id ).first ()
261+ if document is None or Status (document .status )[TaskType .EMBEDDING ] == State .REVOKE :
262+ return True
263+ return False
264+
265+ # 根据段落进行向量化处理
266+ page (QuerySet (Paragraph ).filter (document_id = document_id ).values ('id' ), 5 ,
267+ ListenerManagement .get_embedding_paragraph_apply (embedding_model , is_the_task_interrupted ,
268+ ListenerManagement .get_aggregation_document_status (
269+ document_id )),
270+ is_the_task_interrupted )
175271 except Exception as e :
176272 max_kb_error .error (f'向量化文档:{ document_id } 出现错误{ str (e )} { traceback .format_exc ()} ' )
177- status = Status .error
178273 finally :
179- # 修改状态
180- QuerySet (Document ).filter (id = document_id ).update (
181- ** {'status' : status , 'update_time' : datetime .datetime .now ()})
182- QuerySet (Paragraph ).filter (document_id = document_id ).update (** {'status' : status })
274+ ListenerManagement .post_update_document_status (document_id , TaskType .EMBEDDING )
275+ ListenerManagement .get_aggregation_document_status (document_id )()
183276 max_kb .info (f"结束--->向量化文档:{ document_id } " )
184277 un_lock ('embedding' + str (document_id ))
185278
0 commit comments