1212
1313from common .db .search import native_search
1414from common .event import ListenerManagement
15+ from common .event .common import work_thread_pool
1516from common .exception .app_exception import AppApiException
1617from common .handle .impl .text .csv_split_handle import CsvSplitHandle
1718from common .handle .impl .text .doc_split_handle import DocSplitHandle
2122from common .handle .impl .text .xls_split_handle import XlsSplitHandle
2223from common .handle .impl .text .xlsx_split_handle import XlsxSplitHandle
2324from common .handle .impl .text .zip_split_handle import ZipSplitHandle
24- from common .utils .common import post , get_file_content
25+ from common .utils .common import post , get_file_content , bulk_create_in_batches
2526from knowledge .models import Knowledge , Paragraph , Problem , Document , KnowledgeType , ProblemParagraphMapping , State , \
2627 TaskType , File
27- from knowledge .serializers .common import ProblemParagraphManage
28- from knowledge .serializers .paragraph import ParagraphSerializers , ParagraphInstanceSerializer
29- from knowledge .task import embedding_by_document
28+ from knowledge .serializers .common import ProblemParagraphManage , BatchSerializer
29+ from knowledge .serializers .paragraph import ParagraphSerializers , ParagraphInstanceSerializer , \
30+ delete_problems_and_mappings
31+ from knowledge .task import embedding_by_document , delete_embedding_by_document_list
3032from maxkb .const import PROJECT_DIR
3133
3234default_split_handle = TextSplitHandle ()
4244]
4345
4446
class BatchCancelInstanceSerializer(serializers.Serializer):
    """Payload for cancelling in-flight tasks on a batch of documents.

    ``id_list`` holds the document ids whose tasks should be revoked and
    ``type`` must correspond to a valid :class:`TaskType` member.
    """
    # NOTE: label wrapped in _() for i18n, consistent with every other field in this module.
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
                                    label=_('id list'))
    type = serializers.IntegerField(required=True, label=_('task type'))

    def is_valid(self, *, raise_exception=False):
        # Always raise on field errors so callers get a uniform error response.
        super().is_valid(raise_exception=True)
        _type = self.data.get('type')
        try:
            TaskType(_type)
        except ValueError:
            # Integer outside the TaskType enum -> client error.
            raise AppApiException(500, _('task type not support'))
58+
59+
4560class DocumentInstanceSerializer (serializers .Serializer ):
4661 name = serializers .CharField (required = True , label = _ ('document name' ), max_length = 128 , min_length = 1 )
4762 paragraphs = ParagraphInstanceSerializer (required = False , many = True , allow_null = True )
@@ -65,6 +80,17 @@ class DocumentSplitRequest(serializers.Serializer):
6580 with_filter = serializers .BooleanField (required = False , label = _ ('Auto Clean' ))
6681
6782
class DocumentBatchRequest(serializers.Serializer):
    """Request body for creating documents in bulk from uploaded files."""
    # Uploaded files to be split into documents.
    file = serializers.ListField(required=True, label=_('file list'))
    # Optional split limit — presumably caps paragraph size during splitting;
    # confirm against the split-handle implementations.
    limit = serializers.IntegerField(required=False, label=_('limit'))
    # Optional custom split patterns applied to the extracted text.
    patterns = serializers.ListField(
        required=False,
        child=serializers.CharField(required=True, label=_('patterns')),
        label=_('patterns')
    )
    # When true, automatically clean the extracted text before splitting.
    with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
92+
93+
6894class DocumentSerializers (serializers .Serializer ):
6995 class Operate (serializers .Serializer ):
7096 document_id = serializers .UUIDField (required = True , label = _ ('document id' ))
@@ -264,6 +290,150 @@ def file_to_paragraph(self, file, pattern_list: List, with_filter: bool, limit:
264290 return result
265291 return [result ]
266292
293+ class Batch (serializers .Serializer ):
294+ workspace_id = serializers .UUIDField (required = True , label = _ ('workspace id' ))
295+ knowledge_id = serializers .UUIDField (required = True , label = _ ('knowledge id' ))
296+
297+ @staticmethod
298+ def post_embedding (document_list , knowledge_id ):
299+ for document_dict in document_list :
300+ DocumentSerializers .Operate (
301+ data = {'knowledge_id' : knowledge_id , 'document_id' : document_dict .get ('id' )}).refresh ()
302+ return document_list
303+
304+ @post (post_function = post_embedding )
305+ @transaction .atomic
306+ def batch_save (self , instance_list : List [Dict ], with_valid = True ):
307+ if with_valid :
308+ self .is_valid (raise_exception = True )
309+ DocumentInstanceSerializer (many = True , data = instance_list ).is_valid (raise_exception = True )
310+ knowledge_id = self .data .get ("knowledge_id" )
311+ document_model_list = []
312+ paragraph_model_list = []
313+ problem_paragraph_object_list = []
314+ # 插入文档
315+ for document in instance_list :
316+ document_paragraph_dict_model = DocumentSerializers .Create .get_document_paragraph_model (knowledge_id ,
317+ document )
318+ document_model_list .append (document_paragraph_dict_model .get ('document' ))
319+ for paragraph in document_paragraph_dict_model .get ('paragraph_model_list' ):
320+ paragraph_model_list .append (paragraph )
321+ for problem_paragraph_object in document_paragraph_dict_model .get ('problem_paragraph_object_list' ):
322+ problem_paragraph_object_list .append (problem_paragraph_object )
323+
324+ problem_model_list , problem_paragraph_mapping_list = (
325+ ProblemParagraphManage (problem_paragraph_object_list , knowledge_id ).to_problem_model_list ()
326+ )
327+ # 插入文档
328+ QuerySet (Document ).bulk_create (document_model_list ) if len (document_model_list ) > 0 else None
329+ # 批量插入段落
330+ bulk_create_in_batches (Paragraph , paragraph_model_list , batch_size = 1000 )
331+ # 批量插入问题
332+ bulk_create_in_batches (Problem , problem_model_list , batch_size = 1000 )
333+ # 批量插入关联问题
334+ bulk_create_in_batches (ProblemParagraphMapping , problem_paragraph_mapping_list , batch_size = 1000 )
335+ # 查询文档
336+ query_set = QuerySet (model = Document )
337+ if len (document_model_list ) == 0 :
338+ return [], knowledge_id
339+ query_set = query_set .filter (** {'id__in' : [d .id for d in document_model_list ]})
340+ return native_search (
341+ {
342+ 'document_custom_sql' : query_set ,
343+ 'order_by_query' : QuerySet (Document ).order_by ('-create_time' , 'id' )
344+ },
345+ select_string = get_file_content (
346+ os .path .join (PROJECT_DIR , "apps" , "knowledge" , 'sql' , 'list_document.sql' )
347+ ),
348+ with_search_one = False
349+ ), knowledge_id
350+
351+ @staticmethod
352+ def _batch_sync (document_id_list : List [str ]):
353+ for document_id in document_id_list :
354+ DocumentSerializers .Sync (data = {'document_id' : document_id }).sync ()
355+
356+ def batch_sync (self , instance : Dict , with_valid = True ):
357+ if with_valid :
358+ BatchSerializer (data = instance ).is_valid (model = Document , raise_exception = True )
359+ self .is_valid (raise_exception = True )
360+ # 异步同步
361+ work_thread_pool .submit (self ._batch_sync , instance .get ('id_list' ))
362+ return True
363+
364+ @transaction .atomic
365+ def batch_delete (self , instance : Dict , with_valid = True ):
366+ if with_valid :
367+ BatchSerializer (data = instance ).is_valid (model = Document , raise_exception = True )
368+ self .is_valid (raise_exception = True )
369+ document_id_list = instance .get ("id_list" )
370+ QuerySet (Document ).filter (id__in = document_id_list ).delete ()
371+ QuerySet (Paragraph ).filter (document_id__in = document_id_list ).delete ()
372+ delete_problems_and_mappings (document_id_list )
373+ # 删除向量库
374+ delete_embedding_by_document_list (document_id_list )
375+ return True
376+
377+ def batch_cancel (self , instance : Dict , with_valid = True ):
378+ if with_valid :
379+ self .is_valid (raise_exception = True )
380+ BatchCancelInstanceSerializer (data = instance ).is_valid (raise_exception = True )
381+ document_id_list = instance .get ("id_list" )
382+ ListenerManagement .update_status (
383+ QuerySet (Paragraph ).annotate (
384+ reversed_status = Reverse ('status' ),
385+ task_type_status = Substr ('reversed_status' , TaskType (instance .get ('type' )).value , 1 ),
386+ ).filter (
387+ task_type_status__in = [State .PENDING .value , State .STARTED .value ]
388+ ).filter (
389+ document_id__in = document_id_list
390+ ).values ('id' ),
391+ TaskType (instance .get ('type' )),
392+ State .REVOKE
393+ )
394+ ListenerManagement .update_status (
395+ QuerySet (Document ).annotate (
396+ reversed_status = Reverse ('status' ),
397+ task_type_status = Substr ('reversed_status' , TaskType (instance .get ('type' )).value , 1 ),
398+ ).filter (
399+ task_type_status__in = [State .PENDING .value , State .STARTED .value ]
400+ ).filter (
401+ id__in = document_id_list
402+ ).values ('id' ),
403+ TaskType (instance .get ('type' )),
404+ State .REVOKE
405+ )
406+
407+ def batch_edit_hit_handling (self , instance : Dict , with_valid = True ):
408+ if with_valid :
409+ BatchSerializer (data = instance ).is_valid (model = Document , raise_exception = True )
410+ hit_handling_method = instance .get ('hit_handling_method' )
411+ if hit_handling_method is None :
412+ raise AppApiException (500 , _ ('Hit handling method is required' ))
413+ if hit_handling_method != 'optimization' and hit_handling_method != 'directly_return' :
414+ raise AppApiException (500 , _ ('The hit processing method must be directly_return|optimization' ))
415+ self .is_valid (raise_exception = True )
416+ document_id_list = instance .get ("id_list" )
417+ hit_handling_method = instance .get ('hit_handling_method' )
418+ directly_return_similarity = instance .get ('directly_return_similarity' )
419+ update_dict = {'hit_handling_method' : hit_handling_method }
420+ if directly_return_similarity is not None :
421+ update_dict ['directly_return_similarity' ] = directly_return_similarity
422+ QuerySet (Document ).filter (id__in = document_id_list ).update (** update_dict )
423+
424+ def batch_refresh (self , instance : Dict , with_valid = True ):
425+ if with_valid :
426+ self .is_valid (raise_exception = True )
427+ document_id_list = instance .get ("id_list" )
428+ state_list = instance .get ("state_list" )
429+ knowledge_id = self .data .get ('knowledge_id' )
430+ for document_id in document_id_list :
431+ try :
432+ DocumentSerializers .Operate (
433+ data = {'knowledge_id' : knowledge_id , 'document_id' : document_id }).refresh (state_list )
434+ except AlreadyQueued as e :
435+ pass
436+
267437
268438class FileBufferHandle :
269439 buffer = None
0 commit comments