Skip to content

Commit ac698e0

Browse files
committed
feat: add BatchRefresh and BatchGenerateRelated APIs for document processing
1 parent 54836d5 commit ac698e0

File tree

4 files changed

+156
-10
lines changed

4 files changed

+156
-10
lines changed

apps/knowledge/api/document.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from common.result import DefaultResultSerializer
66
from knowledge.serializers.common import BatchSerializer
77
from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer, \
8-
CancelInstanceSerializer, BatchCancelInstanceSerializer, DocumentRefreshSerializer, BatchEditHitHandlingSerializer
8+
CancelInstanceSerializer, BatchCancelInstanceSerializer, DocumentRefreshSerializer, BatchEditHitHandlingSerializer, \
9+
DocumentBatchRefreshSerializer, DocumentBatchGenerateRelatedSerializer
910

1011

1112
class DocumentSplitAPI(APIMixin):
@@ -356,3 +357,52 @@ def get_parameters():
356357
@staticmethod
357358
def get_response():
358359
return DefaultResultSerializer
360+
361+
362+
class BatchRefreshAPI(APIMixin):
363+
@staticmethod
364+
def get_parameters():
365+
return [
366+
OpenApiParameter(
367+
name="workspace_id",
368+
description="工作空间id",
369+
type=OpenApiTypes.STR,
370+
location='path',
371+
required=True,
372+
),
373+
OpenApiParameter(
374+
name="knowledge_id",
375+
description="知识库id",
376+
type=OpenApiTypes.STR,
377+
location='path',
378+
required=True,
379+
),
380+
]
381+
382+
@staticmethod
383+
def get_request():
384+
return DocumentBatchRefreshSerializer
385+
386+
class BatchGenerateRelatedAPI(APIMixin):
387+
@staticmethod
388+
def get_parameters():
389+
return [
390+
OpenApiParameter(
391+
name="workspace_id",
392+
description="工作空间id",
393+
type=OpenApiTypes.STR,
394+
location='path',
395+
required=True,
396+
),
397+
OpenApiParameter(
398+
name="knowledge_id",
399+
description="知识库id",
400+
type=OpenApiTypes.STR,
401+
location='path',
402+
required=True,
403+
),
404+
]
405+
406+
@staticmethod
407+
def get_request():
408+
return DocumentBatchGenerateRelatedSerializer

apps/knowledge/serializers/document.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
delete_problems_and_mappings
4545
from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
4646
delete_embedding_by_document
47+
from knowledge.task.generate import generate_related_by_document_id
4748
from knowledge.task.sync import sync_web_document
4849
from maxkb.const import PROJECT_DIR
4950

@@ -109,17 +110,17 @@ class DocumentEditInstanceSerializer(serializers.Serializer):
109110

110111
@staticmethod
111112
def get_meta_valid_map():
112-
dataset_meta_valid_map = {
113+
knowledge_meta_valid_map = {
113114
KnowledgeType.BASE: MetaSerializer.BaseMeta,
114115
KnowledgeType.WEB: MetaSerializer.WebMeta
115116
}
116-
return dataset_meta_valid_map
117+
return knowledge_meta_valid_map
117118

118119
def is_valid(self, *, document: Document = None):
119120
super().is_valid(raise_exception=True)
120121
if 'meta' in self.data and self.data.get('meta') is not None:
121-
dataset_meta_valid_map = self.get_meta_valid_map()
122-
valid_class = dataset_meta_valid_map.get(document.type)
122+
knowledge_meta_valid_map = self.get_meta_valid_map()
123+
valid_class = knowledge_meta_valid_map.get(document.type)
123124
valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
124125

125126

@@ -154,6 +155,18 @@ class DocumentRefreshSerializer(serializers.Serializer):
154155
state_list = serializers.ListField(required=True, label=_('state list'))
155156

156157

158+
class DocumentBatchRefreshSerializer(serializers.Serializer):
159+
id_list = serializers.ListField(required=True, label=_('id list'))
160+
state_list = serializers.ListField(required=True, label=_('state list'))
161+
162+
163+
class DocumentBatchGenerateRelatedSerializer(serializers.Serializer):
164+
document_id_list = serializers.ListField(required=True, label=_('document id list'))
165+
model_id = serializers.UUIDField(required=True, label=_('model id'))
166+
prompt = serializers.CharField(required=True, label=_('prompt'))
167+
state_list = serializers.ListField(required=True, label=_('state list'))
168+
169+
157170
class BatchEditHitHandlingSerializer(serializers.Serializer):
158171
id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
159172
hit_handling_method = serializers.CharField(required=True, label=_('hit handling method'))
@@ -521,26 +534,28 @@ def save_web(self, instance: Dict, with_valid=True):
521534
if with_valid:
522535
DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
523536
self.is_valid(raise_exception=True)
524-
dataset_id = self.data.get('dataset_id')
537+
knowledge_id = self.data.get('knowledge_id')
525538
source_url_list = instance.get('source_url_list')
526539
selector = instance.get('selector')
527-
sync_web_document.delay(dataset_id, source_url_list, selector)
540+
sync_web_document.delay(knowledge_id, source_url_list, selector)
528541

529542
def save_qa(self, instance: Dict, with_valid=True):
530543
if with_valid:
531544
DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
532545
self.is_valid(raise_exception=True)
533546
file_list = instance.get('file_list')
534547
document_list = flat_map([self.parse_qa_file(file) for file in file_list])
535-
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
548+
return DocumentSerializers.Batch(data={'knowledge_id': self.data.get('knowledge_id')}).batch_save(
549+
document_list)
536550

537551
def save_table(self, instance: Dict, with_valid=True):
538552
if with_valid:
539553
DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True)
540554
self.is_valid(raise_exception=True)
541555
file_list = instance.get('file_list')
542556
document_list = flat_map([self.parse_table_file(file) for file in file_list])
543-
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
557+
return DocumentSerializers.Batch(data={'knowledge_id': self.data.get('knowledge_id')}).batch_save(
558+
document_list)
544559

545560
def parse_qa_file(self, file):
546561
get_buffer = FileBufferHandle().get_buffer
@@ -788,6 +803,42 @@ def batch_refresh(self, instance: Dict, with_valid=True):
788803
except AlreadyQueued as e:
789804
pass
790805

806+
class BatchGenerateRelated(serializers.Serializer):
807+
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
808+
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
809+
810+
def batch_generate_related(self, instance: Dict, with_valid=True):
811+
if with_valid:
812+
self.is_valid(raise_exception=True)
813+
document_id_list = instance.get("document_id_list")
814+
model_id = instance.get("model_id")
815+
prompt = instance.get("prompt")
816+
state_list = instance.get('state_list')
817+
ListenerManagement.update_status(
818+
QuerySet(Document).filter(id__in=document_id_list),
819+
TaskType.GENERATE_PROBLEM,
820+
State.PENDING
821+
)
822+
ListenerManagement.update_status(
823+
QuerySet(Paragraph).annotate(
824+
reversed_status=Reverse('status'),
825+
task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
826+
1),
827+
).filter(
828+
task_type_status__in=state_list, document_id__in=document_id_list
829+
)
830+
.values('id'),
831+
TaskType.GENERATE_PROBLEM,
832+
State.PENDING
833+
)
834+
ListenerManagement.get_aggregation_document_status_by_query_set(
835+
QuerySet(Document).filter(id__in=document_id_list))()
836+
try:
837+
for document_id in document_id_list:
838+
generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
839+
except AlreadyQueued as e:
840+
pass
841+
791842

792843
class FileBufferHandle:
793844
buffer = None

apps/knowledge/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch_create', views.DocumentView.BatchCreate.as_view()),
1919
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch_sync', views.DocumentView.BatchSync.as_view()),
2020
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch_delete', views.DocumentView.BatchDelete.as_view()),
21+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch_refresh', views.DocumentView.BatchRefresh.as_view()),
22+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch_generate_related', views.DocumentView.BatchGenerateRelated.as_view()),
2123
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/web', views.WebDocumentView.as_view()),
2224
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/qa', views.QaDocumentView.as_view()),
2325
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/table', views.TableDocumentView.as_view()),

apps/knowledge/views/document.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentBatchCreateAPI, DocumentCreateAPI, \
1212
DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \
1313
WebDocumentCreateAPI, CancelTaskAPI, BatchCancelTaskAPI, SyncWebAPI, RefreshAPI, BatchEditHitHandlingAPI, \
14-
DocumentTreeReadAPI, DocumentSplitPatternAPI
14+
DocumentTreeReadAPI, DocumentSplitPatternAPI, BatchRefreshAPI, BatchGenerateRelatedAPI
1515
from knowledge.serializers.document import DocumentSerializers
1616

1717

@@ -314,6 +314,49 @@ def put(self, request: Request, workspace_id: str, knowledge_id: str):
314314
data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
315315
).batch_delete(request.data))
316316

317+
class BatchRefresh(APIView):
318+
authentication_classes = [TokenAuth]
319+
320+
@extend_schema(
321+
methods=['PUT'],
322+
summary=_('Batch refresh document vector library'),
323+
operation_id=_('Batch refresh document vector library'),
324+
request=BatchRefreshAPI.get_request(),
325+
parameters=BatchRefreshAPI.get_parameters(),
326+
responses=BatchRefreshAPI.get_response(),
327+
tags=[_('Knowledge Base/Documentation')]
328+
)
329+
@has_permissions([
330+
PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
331+
PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
332+
])
333+
def put(self, request: Request, workspace_id: str, knowledge_id: str):
334+
return result.success(
335+
DocumentSerializers.Batch(
336+
data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
337+
).batch_refresh(request.data))
338+
339+
class BatchGenerateRelated(APIView):
340+
authentication_classes = [TokenAuth]
341+
342+
@extend_schema(
343+
methods=['PUT'],
344+
summary=_('Batch refresh document vector library'),
345+
operation_id=_('Batch refresh document vector library'),
346+
request=BatchGenerateRelatedAPI.get_request(),
347+
parameters=BatchGenerateRelatedAPI.get_parameters(),
348+
responses=BatchGenerateRelatedAPI.get_response(),
349+
tags=[_('Knowledge Base/Documentation')]
350+
)
351+
@has_permissions([
352+
PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
353+
PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
354+
])
355+
def put(self, request: Request, workspace_id: str, knowledge_id: str):
356+
return result.success(DocumentSerializers.BatchGenerateRelated(
357+
data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
358+
).batch_generate_related(request.data))
359+
317360
class Page(APIView):
318361
authentication_classes = [TokenAuth]
319362

0 commit comments

Comments
 (0)