Skip to content

Commit 0d3eb43

Browse files
committed
feat: implement batch processing for document creation, synchronization, and deletion
1 parent 43bef21 commit 0d3eb43

File tree

5 files changed

+272
-14
lines changed

5 files changed

+272
-14
lines changed

apps/common/utils/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,10 @@ def parse_md_image(content: str):
263263
image_list = [match.group() for match in matches]
264264
return image_list
265265

266+
def bulk_create_in_batches(model, data, batch_size=1000):
    """Insert *data* into *model*'s table in chunks of at most *batch_size* rows.

    Splitting one huge ``bulk_create`` into fixed-size batches keeps the
    generated INSERT statement (and per-call memory) bounded when the caller
    passes tens of thousands of unsaved instances.

    :param model: a Django model class; ``model.objects.bulk_create`` is used.
    :param data: sequence of unsaved model instances; may be empty.
    :param batch_size: maximum number of rows submitted per ``bulk_create`` call.
    """
    # Empty input: nothing to insert — avoid issuing any query at all.
    if not data:
        return
    for start in range(0, len(data), batch_size):
        model.objects.bulk_create(data[start:start + batch_size])
272+

apps/knowledge/api/document.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,11 @@
22
from drf_spectacular.utils import OpenApiParameter
33

44
from common.mixins.api_mixin import APIMixin
5-
from common.result import DefaultResultSerializer, ResultSerializer
5+
from common.result import DefaultResultSerializer
6+
from knowledge.serializers.common import BatchSerializer
67
from knowledge.serializers.document import DocumentCreateRequest
78

89

9-
class DocumentCreateResponse(ResultSerializer):
10-
@staticmethod
11-
def get_data():
12-
return DefaultResultSerializer()
13-
14-
1510
class DocumentCreateAPI(APIMixin):
1611
@staticmethod
1712
def get_parameters():
@@ -31,7 +26,7 @@ def get_request():
3126

3227
@staticmethod
3328
def get_response():
34-
return DocumentCreateResponse
29+
return DefaultResultSerializer
3530

3631

3732
class DocumentSplitAPI(APIMixin):
@@ -75,3 +70,31 @@ def get_parameters():
7570
),
7671
]
7772

73+
74+
class DocumentBatchAPI(APIMixin):
    """OpenAPI schema for the batch document endpoints.

    Declares the two path parameters shared by every batch operation and
    wires up the request/response serializers consumed by drf-spectacular.
    """

    # (name, description) pairs for the URL path variables; every batch
    # endpoint is addressed by workspace and knowledge base.
    _PATH_PARAMS = (
        ("workspace_id", "工作空间id"),
        ("knowledge_id", "知识库id"),
    )

    @staticmethod
    def get_parameters():
        """Build the OpenApiParameter list for the path variables."""
        return [
            OpenApiParameter(
                name=param_name,
                description=param_description,
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            )
            for param_name, param_description in DocumentBatchAPI._PATH_PARAMS
        ]

    @staticmethod
    def get_request():
        """Request body schema: a batch payload (BatchSerializer)."""
        return BatchSerializer

    @staticmethod
    def get_response():
        """Responses use the project-wide default result envelope."""
        return DefaultResultSerializer

apps/knowledge/serializers/document.py

Lines changed: 174 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from common.db.search import native_search
1414
from common.event import ListenerManagement
15+
from common.event.common import work_thread_pool
1516
from common.exception.app_exception import AppApiException
1617
from common.handle.impl.text.csv_split_handle import CsvSplitHandle
1718
from common.handle.impl.text.doc_split_handle import DocSplitHandle
@@ -21,12 +22,13 @@
2122
from common.handle.impl.text.xls_split_handle import XlsSplitHandle
2223
from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle
2324
from common.handle.impl.text.zip_split_handle import ZipSplitHandle
24-
from common.utils.common import post, get_file_content
25+
from common.utils.common import post, get_file_content, bulk_create_in_batches
2526
from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
2627
TaskType, File
27-
from knowledge.serializers.common import ProblemParagraphManage
28-
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer
29-
from knowledge.task import embedding_by_document
28+
from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer
29+
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
30+
delete_problems_and_mappings
31+
from knowledge.task import embedding_by_document, delete_embedding_by_document_list
3032
from maxkb.const import PROJECT_DIR
3133

3234
default_split_handle = TextSplitHandle()
@@ -42,6 +44,19 @@
4244
]
4345

4446

47+
class BatchCancelInstanceSerializer(serializers.Serializer):
    """Payload for cancelling in-flight tasks on a batch of documents.

    ``id_list`` holds the document ids whose pending/started tasks should be
    revoked; ``type`` must map to a valid ``TaskType`` member.
    """

    # Fix: the original passed label=('id list') — a plain parenthesized
    # string — while the sibling field used _(...); wrap it for i18n
    # consistency with the rest of the file.
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
                                    label=_('id list'))
    type = serializers.IntegerField(required=True, label=_('task type'))

    def is_valid(self, *, raise_exception=False):
        """Validate the declared fields, then check ``type`` is a TaskType.

        Note: validation errors are always raised; ``raise_exception`` is
        accepted only for signature compatibility with DRF callers.
        """
        super().is_valid(raise_exception=True)
        _type = self.data.get('type')
        try:
            TaskType(_type)
        except ValueError as e:
            # Surface a structured API error instead of a bare ValueError,
            # chaining the cause for debuggability.
            raise AppApiException(500, _('task type not support')) from e
58+
59+
4560
class DocumentInstanceSerializer(serializers.Serializer):
4661
name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1)
4762
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
@@ -65,6 +80,17 @@ class DocumentSplitRequest(serializers.Serializer):
6580
with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
6681

6782

83+
class DocumentBatchRequest(serializers.Serializer):
    """Request body for creating documents in batches from uploaded files.

    Appears to mirror DocumentSplitRequest's split settings — TODO confirm
    the two stay in sync.
    """

    # Uploaded files to be converted into documents.
    file = serializers.ListField(required=True, label=_('file list'))
    # Optional paragraph-length limit used when splitting.
    limit = serializers.IntegerField(required=False, label=_('limit'))
    # Optional custom split patterns; each entry is one pattern string.
    patterns = serializers.ListField(
        required=False,
        child=serializers.CharField(required=True, label=_('patterns')),
        label=_('patterns')
    )
    # Whether to auto-clean the split text.
    with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
92+
93+
6894
class DocumentSerializers(serializers.Serializer):
6995
class Operate(serializers.Serializer):
7096
document_id = serializers.UUIDField(required=True, label=_('document id'))
@@ -264,6 +290,150 @@ def file_to_paragraph(self, file, pattern_list: List, with_filter: bool, limit:
264290
return result
265291
return [result]
266292

293+
class Batch(serializers.Serializer):
    """Batch operations over the documents of one knowledge base.

    Every entry point validates the (workspace_id, knowledge_id) pair held by
    this serializer, then acts on an ``id_list`` (or instance list) supplied
    in the request payload.
    """

    workspace_id = serializers.UUIDField(required=True, label=_('workspace id'))
    knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

    @staticmethod
    def post_embedding(document_list, knowledge_id):
        """Post-hook for batch_save: trigger an embedding refresh per created document."""
        for document_dict in document_list:
            DocumentSerializers.Operate(
                data={'knowledge_id': knowledge_id, 'document_id': document_dict.get('id')}).refresh()
        return document_list

    @post(post_function=post_embedding)
    @transaction.atomic
    def batch_save(self, instance_list: List[Dict], with_valid=True):
        """Create many documents (with their paragraphs/problems) in one transaction.

        Returns (created document rows, knowledge_id); the ``@post`` decorator
        then feeds both values into ``post_embedding``.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
        knowledge_id = self.data.get("knowledge_id")
        document_model_list = []
        paragraph_model_list = []
        problem_paragraph_object_list = []
        # Build unsaved model instances for every document and its paragraphs.
        for document in instance_list:
            document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(knowledge_id,
                                                                                                    document)
            document_model_list.append(document_paragraph_dict_model.get('document'))
            for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
                paragraph_model_list.append(paragraph)
            for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
                problem_paragraph_object_list.append(problem_paragraph_object)

        problem_model_list, problem_paragraph_mapping_list = (
            ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list()
        )
        # Insert documents.
        QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
        # Batch-insert paragraphs.
        bulk_create_in_batches(Paragraph, paragraph_model_list, batch_size=1000)
        # Batch-insert problems.
        bulk_create_in_batches(Problem, problem_model_list, batch_size=1000)
        # Batch-insert problem/paragraph mappings.
        bulk_create_in_batches(ProblemParagraphMapping, problem_paragraph_mapping_list, batch_size=1000)
        # Re-query the created documents so the response matches list_document.sql.
        query_set = QuerySet(model=Document)
        if len(document_model_list) == 0:
            return [], knowledge_id
        query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
        return native_search(
            {
                'document_custom_sql': query_set,
                'order_by_query': QuerySet(Document).order_by('-create_time', 'id')
            },
            select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')
            ),
            with_search_one=False
        ), knowledge_id

    @staticmethod
    def _batch_sync(document_id_list: List[str]):
        """Synchronously sync each document; intended to run on a worker thread."""
        for document_id in document_id_list:
            DocumentSerializers.Sync(data={'document_id': document_id}).sync()

    def batch_sync(self, instance: Dict, with_valid=True):
        """Kick off an asynchronous sync for the given document ids."""
        if with_valid:
            BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
            self.is_valid(raise_exception=True)
        # Sync asynchronously on the shared worker pool; returns immediately.
        work_thread_pool.submit(self._batch_sync, instance.get('id_list'))
        return True

    @transaction.atomic
    def batch_delete(self, instance: Dict, with_valid=True):
        """Delete documents plus their paragraphs, problem mappings and embeddings."""
        if with_valid:
            BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
            self.is_valid(raise_exception=True)
        document_id_list = instance.get("id_list")
        QuerySet(Document).filter(id__in=document_id_list).delete()
        QuerySet(Paragraph).filter(document_id__in=document_id_list).delete()
        delete_problems_and_mappings(document_id_list)
        # Remove the documents' vectors from the vector store.
        # NOTE(review): this runs inside the atomic block, so the vector
        # delete fires even if the transaction later rolls back — confirm
        # that is acceptable.
        delete_embedding_by_document_list(document_id_list)
        return True

    def batch_cancel(self, instance: Dict, with_valid=True):
        """Revoke pending/started tasks of the given type on the listed documents.

        Presumably each task type's state is one character of the reversed
        ``status`` string (``Substr(reversed_status, TaskType.value, 1)``) —
        TODO confirm against the status-encoding helper.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            BatchCancelInstanceSerializer(data=instance).is_valid(raise_exception=True)
        document_id_list = instance.get("id_list")
        # Revoke at paragraph level first, then at document level.
        ListenerManagement.update_status(
            QuerySet(Paragraph).annotate(
                reversed_status=Reverse('status'),
                task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, 1),
            ).filter(
                task_type_status__in=[State.PENDING.value, State.STARTED.value]
            ).filter(
                document_id__in=document_id_list
            ).values('id'),
            TaskType(instance.get('type')),
            State.REVOKE
        )
        ListenerManagement.update_status(
            QuerySet(Document).annotate(
                reversed_status=Reverse('status'),
                task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, 1),
            ).filter(
                task_type_status__in=[State.PENDING.value, State.STARTED.value]
            ).filter(
                id__in=document_id_list
            ).values('id'),
            TaskType(instance.get('type')),
            State.REVOKE
        )

    def batch_edit_hit_handling(self, instance: Dict, with_valid=True):
        """Bulk-update hit_handling_method (and optional similarity) on documents.

        Raises AppApiException(500) when the method is missing or not one of
        'optimization' / 'directly_return'.
        """
        if with_valid:
            BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
            hit_handling_method = instance.get('hit_handling_method')
            if hit_handling_method is None:
                raise AppApiException(500, _('Hit handling method is required'))
            if hit_handling_method != 'optimization' and hit_handling_method != 'directly_return':
                raise AppApiException(500, _('The hit processing method must be directly_return|optimization'))
            self.is_valid(raise_exception=True)
        document_id_list = instance.get("id_list")
        hit_handling_method = instance.get('hit_handling_method')
        directly_return_similarity = instance.get('directly_return_similarity')
        update_dict = {'hit_handling_method': hit_handling_method}
        # Only overwrite the similarity threshold when the caller supplied one.
        if directly_return_similarity is not None:
            update_dict['directly_return_similarity'] = directly_return_similarity
        QuerySet(Document).filter(id__in=document_id_list).update(**update_dict)

    def batch_refresh(self, instance: Dict, with_valid=True):
        """Re-queue an embedding refresh per document, skipping already-queued ones."""
        if with_valid:
            self.is_valid(raise_exception=True)
        document_id_list = instance.get("id_list")
        state_list = instance.get("state_list")
        knowledge_id = self.data.get('knowledge_id')
        for document_id in document_id_list:
            try:
                DocumentSerializers.Operate(
                    data={'knowledge_id': knowledge_id, 'document_id': document_id}).refresh(state_list)
            except AlreadyQueued as e:
                # Document is already queued for refresh: skip it silently.
                pass
436+
267437

268438
class FileBufferHandle:
269439
buffer = None

apps/knowledge/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@
99
path('workspace/<str:workspace_id>/knowledge/web', views.KnowledgeWebView.as_view()),
1010
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
1111
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
12+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch', views.DocumentView.Batch.as_view()),
1213
path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
1314
]

apps/knowledge/views/document.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
from common.auth import TokenAuth
88
from common.auth.authentication import has_permissions
9-
from common.constants.permission_constants import PermissionConstants, CompareConstants
9+
from common.constants.permission_constants import PermissionConstants
1010
from common.result import result
11-
from knowledge.api.document import DocumentSplitAPI
11+
from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI
1212
from knowledge.api.knowledge import KnowledgeTreeReadAPI
1313
from knowledge.serializers.document import DocumentSerializers
1414
from knowledge.serializers.knowledge import KnowledgeSerializer
@@ -68,3 +68,60 @@ def post(self, request: Request, workspace_id: str, knowledge_id: str):
6868
'workspace_id': workspace_id,
6969
'knowledge_id': knowledge_id,
7070
}).parse(split_data))
71+
72+
class Batch(APIView):
    """Batch document endpoints: create (POST), sync (PUT), delete (DELETE)."""

    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['POST'],
        description=_('Create documents in batches'),
        operation_id=_('Create documents in batches'),
        request=DocumentBatchAPI.get_request(),
        parameters=DocumentBatchAPI.get_parameters(),
        responses=DocumentBatchAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')]
    )
    @has_permissions([
        PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
        PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
    ])
    def post(self, request: Request, workspace_id: str, knowledge_id: str):
        # Create documents in batches; request.data carries the instance list.
        return result.success(DocumentSerializers.Batch(
            data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
        ).batch_save(request.data))

    @extend_schema(
        methods=['PUT'],
        description=_('Batch sync documents'),
        operation_id=_('Batch sync documents'),
        request=DocumentBatchAPI.get_request(),
        parameters=DocumentBatchAPI.get_parameters(),
        responses=DocumentBatchAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')]
    )
    @has_permissions([
        PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
        PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
    ])
    def put(self, request: Request, workspace_id: str, knowledge_id: str):
        # Queue an asynchronous sync for the ids in request.data.
        return result.success(DocumentSerializers.Batch(
            data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
        ).batch_sync(request.data))

    @extend_schema(
        methods=['DELETE'],
        description=_('Delete documents in batches'),
        operation_id=_('Delete documents in batches'),
        request=DocumentBatchAPI.get_request(),
        parameters=DocumentBatchAPI.get_parameters(),
        responses=DocumentBatchAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')]
    )
    @has_permissions([
        # NOTE(review): DELETE is guarded by CREATE/EDIT permissions — confirm
        # whether a dedicated DOCUMENT_DELETE permission was intended here.
        PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
        PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
    ])
    def delete(self, request: Request, workspace_id: str, knowledge_id: str):
        # Delete documents (and dependent rows/embeddings) for the given ids.
        return result.success(DocumentSerializers.Batch(
            data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
        ).batch_delete(request.data))

0 commit comments

Comments
 (0)