Skip to content

Commit c3b979d

Browse files
committed
feat: add CancelTaskAPI and batch cancellation endpoints for document tasks
1 parent e702af8 commit c3b979d

File tree

6 files changed

+223
-6
lines changed

6 files changed

+223
-6
lines changed

apps/knowledge/api/document.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from common.mixins.api_mixin import APIMixin
55
from common.result import DefaultResultSerializer
66
from knowledge.serializers.common import BatchSerializer
7-
from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer
7+
from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer, \
8+
CancelInstanceSerializer, BatchCancelInstanceSerializer, DocumentRefreshSerializer, BatchEditHitHandlingSerializer
89

910

1011
class DocumentSplitAPI(APIMixin):
@@ -218,3 +219,50 @@ class WebDocumentCreateAPI(APIMixin):
218219
@staticmethod
219220
def get_request():
220221
return DocumentWebInstanceSerializer
222+
223+
224+
class CancelTaskAPI(DocumentReadAPI):
    """Schema helper for the single-document cancel-task endpoint.

    Only the request body is specialised here; every other schema hook is
    whatever ``DocumentReadAPI`` provides.
    """

    @staticmethod
    def get_request():
        """Return the serializer class that describes the request body."""
        request_serializer = CancelInstanceSerializer
        return request_serializer
228+
229+
230+
class BatchCancelTaskAPI(DocumentReadAPI):
    """Schema helper for the batch cancel-task endpoint.

    Only the request body is specialised here; every other schema hook is
    whatever ``DocumentReadAPI`` provides.
    """

    @staticmethod
    def get_request():
        """Return the serializer class that describes the request body."""
        request_serializer = BatchCancelInstanceSerializer
        return request_serializer
234+
235+
236+
class SyncWebAPI(DocumentReadAPI):
    """Schema helper for the web-document sync endpoint.

    Adds nothing of its own: all schema hooks are taken unchanged from
    ``DocumentReadAPI``.
    """
238+
239+
240+
class RefreshAPI(DocumentReadAPI):
    """Schema helper for the document refresh endpoint.

    Only the request body is specialised here; every other schema hook is
    whatever ``DocumentReadAPI`` provides.
    """

    @staticmethod
    def get_request():
        """Return the serializer class that describes the request body."""
        request_serializer = DocumentRefreshSerializer
        return request_serializer
244+
245+
246+
class BatchEditHitHandlingAPI(APIMixin):
    """Schema helper for the batch "edit hit handling" endpoint."""

    @staticmethod
    def get_parameters():
        """Return the two required string path parameters of the route."""
        # (name, description) pairs; both are required string path segments.
        path_parameters = (
            ("workspace_id", "工作空间id"),
            ("knowledge_id", "知识库id"),
        )
        return [
            OpenApiParameter(
                name=parameter_name,
                description=parameter_description,
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            )
            for parameter_name, parameter_description in path_parameters
        ]

    @staticmethod
    def get_request():
        """Return the serializer class that describes the request body."""
        request_serializer = BatchEditHitHandlingSerializer
        return request_serializer

apps/knowledge/serializers/document.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464

6565

6666
class BatchCancelInstanceSerializer(serializers.Serializer):
67-
id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=('id list'))
67+
id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
6868
type = serializers.IntegerField(required=True, label=_('task type'))
6969

7070
def is_valid(self, *, raise_exception=False):
@@ -81,6 +81,18 @@ class DocumentInstanceSerializer(serializers.Serializer):
8181
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
8282

8383

84+
class CancelInstanceSerializer(serializers.Serializer):
    # Request body for cancelling one kind of task: a single integer ``type``
    # that must be a valid ``TaskType`` value.
    type = serializers.IntegerField(required=True, label=_('task type'))

    def is_valid(self, *, raise_exception=False):
        """Validate the payload and check that ``type`` is a known ``TaskType``.

        ``raise_exception`` is accepted only for signature compatibility with
        the base serializer. It is not consulted: field validation always
        raises on failure, because ``cancel()`` in this module calls
        ``is_valid()`` without arguments and depends on invalid input raising.

        Returns:
            True when the payload passed every check.

        Raises:
            AppApiException: with code 500 when ``type`` cannot be converted
                to a ``TaskType``.
        """
        super().is_valid(raise_exception=True)
        task_type = self.data.get('type')
        try:
            TaskType(task_type)
        except Exception as e:
            # Keep the original failure attached as the cause for debugging.
            raise AppApiException(500, _('task type not support')) from e
        # The base-class contract is a boolean result; the original fell off
        # the end and returned None (falsy) even for valid input.
        return True
94+
95+
8496
class DocumentEditInstanceSerializer(serializers.Serializer):
8597
meta = serializers.DictField(required=False)
8698
name = serializers.CharField(required=False, max_length=128, min_length=1, label=_('document name'))
@@ -138,6 +150,22 @@ class DocumentInstanceTableSerializer(serializers.Serializer):
138150
child=serializers.FileField(required=True, label=_('file')))
139151

140152

153+
class DocumentRefreshSerializer(serializers.Serializer):
    # Request body of the refresh endpoint: a required list under
    # ``state_list`` (no child field is declared, so items are not
    # individually validated here).
    state_list = serializers.ListField(
        label=_('state list'),
        required=True,
    )
155+
156+
157+
class BatchEditHitHandlingSerializer(serializers.Serializer):
    # Request body for changing the hit-handling method of several documents.
    # Documents to update, as UUIDs.
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
    # Must be 'optimization' or 'directly_return' (enforced in is_valid).
    hit_handling_method = serializers.CharField(required=True, label=_('hit handling method'))
    # Optional threshold, constrained to the closed range [0, 2].
    directly_return_similarity = serializers.FloatField(required=False, max_value=2, min_value=0,
                                                        label=_('directly return similarity'))

    def is_valid(self, *, raise_exception=False):
        """Validate the payload and restrict ``hit_handling_method`` to known values.

        ``raise_exception`` is accepted only for signature compatibility with
        the base serializer; field validation always raises on failure, as it
        did before this change.

        Returns:
            True when the payload passed every check.

        Raises:
            AppApiException: with code 500 when ``hit_handling_method`` is
                neither ``'optimization'`` nor ``'directly_return'``.
        """
        super().is_valid(raise_exception=True)
        supported_methods = ('optimization', 'directly_return')
        if self.data.get('hit_handling_method') not in supported_methods:
            raise AppApiException(500, _('The type only supports optimization|directly_return'))
        # The base-class contract is a boolean result; the original fell off
        # the end and returned None (falsy) even for valid input.
        return True
167+
168+
141169
class DocumentSerializers(serializers.Serializer):
142170
class Query(serializers.Serializer):
143171
# 知识库id
@@ -201,6 +229,8 @@ def page(self, current_page, page_size):
201229
os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')))
202230

203231
class Sync(serializers.Serializer):
232+
workspace_id = serializers.CharField(required=False, label=_('workspace id'))
233+
knowledge_id = serializers.UUIDField(required=False, label=_('knowledge id'))
204234
document_id = serializers.UUIDField(required=True, label=_('document id'))
205235

206236
def is_valid(self, *, raise_exception=False):
@@ -320,6 +350,38 @@ def edit(self, instance: Dict, with_valid=False):
320350
_document.save()
321351
return self.one()
322352

353+
def cancel(self, instance, with_valid=True):
    """Mark pending/started work of one task type on this document as revoked.

    The task type is taken from ``instance['type']``. For the document's
    paragraphs first, and then for the document row itself, the rows whose
    status character for that task type is PENDING or STARTED are handed to
    ``ListenerManagement.update_status`` with ``State.REVOKE``.

    Args:
        instance: mapping read with ``.get('type')``; the value must be
            accepted by ``TaskType(...)``.
        with_valid: when true, validate this serializer and the payload
            before touching anything.

    Returns:
        True.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        # NOTE(review): the source view has no indentation; placing the
        # payload check under ``with_valid`` is a reconstruction -- confirm.
        CancelInstanceSerializer(data=instance).is_valid()
    document_id = self.data.get("document_id")
    task_type = TaskType(instance.get('type'))
    cancellable_states = [State.PENDING.value, State.STARTED.value]
    # (model, filter selecting the rows that belong to this document);
    # paragraphs are processed before the document itself.
    targets = (
        (Paragraph, {'document_id': document_id}),
        (Document, {'id': document_id}),
    )
    for model, ownership_filter in targets:
        # The status string is reversed and the single character at position
        # ``task_type.value`` is extracted as that task type's state.
        cancellable_rows = QuerySet(model).annotate(
            reversed_status=Reverse('status'),
            task_type_status=Substr('reversed_status', task_type.value, 1),
        ).filter(
            task_type_status__in=cancellable_states
        ).filter(
            **ownership_filter
        ).values('id')
        ListenerManagement.update_status(cancellable_rows, task_type, State.REVOKE)
    return True
384+
323385
@transaction.atomic
324386
def delete(self):
325387
document_id = self.data.get("document_id")

apps/knowledge/task/handler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
from common.utils.fork import ChildLink, Fork
1212
from common.utils.split_model import get_split_model
1313
from knowledge.models.knowledge import KnowledgeType, Document, Knowledge, Status
14-
from knowledge.serializers.document import DocumentSerializers
15-
from knowledge.serializers.paragraph import ParagraphSerializers
1614

1715
max_kb_error = logging.getLogger("max_kb_error")
1816
max_kb = logging.getLogger("max_kb")
1917

2018

2119
def get_save_handler(knowledge_id, selector):
20+
from knowledge.serializers.document import DocumentSerializers
21+
2222
def handler(child_link: ChildLink, response: Fork.Response):
2323
if response.status == 200:
2424
try:
@@ -40,6 +40,8 @@ def handler(child_link: ChildLink, response: Fork.Response):
4040

4141

4242
def get_sync_handler(knowledge_id):
43+
from knowledge.serializers.document import DocumentSerializers
44+
4345
knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
4446

4547
def handler(child_link: ChildLink, response: Fork.Response):
@@ -70,6 +72,8 @@ def handler(child_link: ChildLink, response: Fork.Response):
7072

7173

7274
def get_sync_web_document_handler(knowledge_id):
75+
from knowledge.serializers.document import DocumentSerializers
76+
7377
def handler(source_url: str, selector, response: Fork.Response):
7478
if response.status == 200:
7579
try:
@@ -93,6 +97,8 @@ def handler(source_url: str, selector, response: Fork.Response):
9397

9498

9599
def save_problem(knowledge_id, document_id, paragraph_id, problem):
100+
from knowledge.serializers.paragraph import ParagraphSerializers
101+
96102
# print(f"knowledge_id: {knowledge_id}")
97103
# print(f"document_id: {document_id}")
98104
# print(f"paragraph_id: {paragraph_id}")

apps/knowledge/task/sync.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616

1717
from common.utils.fork import ForkManage, Fork
1818
from ops import celery_app
19-
from .handler import get_save_handler, get_sync_web_document_handler, get_sync_handler
2019

2120
max_kb_error = logging.getLogger("max_kb_error")
2221
max_kb = logging.getLogger("max_kb")
2322

2423

2524
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:sync_web_knowledge')
2625
def sync_web_knowledge(knowledge_id: str, url: str, selector: str):
26+
from knowledge.task.handler import get_save_handler
27+
2728
try:
2829
max_kb.info(
2930
_('Start--->Start synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
@@ -39,6 +40,8 @@ def sync_web_knowledge(knowledge_id: str, url: str, selector: str):
3940

4041
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:sync_replace_web_knowledge')
4142
def sync_replace_web_knowledge(knowledge_id: str, url: str, selector: str):
43+
from knowledge.task.handler import get_sync_handler
44+
4245
try:
4346
max_kb.info(
4447
_('Start--->Start synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
@@ -53,6 +56,8 @@ def sync_replace_web_knowledge(knowledge_id: str, url: str, selector: str):
5356

5457
@celery_app.task(name='celery:sync_web_document')
5558
def sync_web_document(knowledge_id, source_url_list: List[str], selector: str):
59+
from knowledge.task.handler import get_sync_web_document_handler
60+
5661
handler = get_sync_web_document_handler(knowledge_id)
5762
for source_url in source_url_list:
5863
try:

apps/knowledge/urls.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/web', views.WebDocumentView.as_view()),
1515
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/qa', views.QaDocumentView.as_view()),
1616
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/table', views.TableDocumentView.as_view()),
17+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch_hit_handling', views.DocumentView.BatchEditHitHandling.as_view()),
1718
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>', views.DocumentView.Operate.as_view()),
19+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/sync', views.DocumentView.SyncWeb.as_view()),
20+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/refresh', views.DocumentView.Refresh.as_view()),
21+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/cancel_task', views.DocumentView.CancelTask.as_view()),
22+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/cancel_task/batch', views.DocumentView.BatchCancelTask.as_view()),
1823
path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
1924
]

apps/knowledge/views/document.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from common.result import result
1111
from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentBatchCreateAPI, DocumentCreateAPI, \
1212
DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \
13-
WebDocumentCreateAPI
13+
WebDocumentCreateAPI, CancelTaskAPI, BatchCancelTaskAPI, SyncWebAPI, RefreshAPI, BatchEditHitHandlingAPI
1414
from knowledge.api.knowledge import KnowledgeTreeReadAPI
1515
from knowledge.serializers.document import DocumentSerializers
1616

@@ -140,6 +140,97 @@ def post(self, request: Request, workspace_id: str, knowledge_id: str):
140140
'knowledge_id': knowledge_id,
141141
}).parse(split_data))
142142

143+
class BatchEditHitHandling(APIView):
    # PUT endpoint that changes the hit-handling method of several documents
    # of one knowledge base in a single request.
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['PUT'],
        summary=_('Modify document hit processing methods in batches'),
        description=_('Modify document hit processing methods in batches'),
        operation_id=_('Modify document hit processing methods in batches'),
        request=BatchEditHitHandlingAPI.get_request(),
        parameters=BatchEditHitHandlingAPI.get_parameters(),
        responses=BatchEditHitHandlingAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')],
    )
    @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str):
        # Route values identify the knowledge base; the body carries the edit.
        scope = {'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
        batch_serializer = DocumentSerializers.Batch(data=scope)
        outcome = batch_serializer.batch_edit_hit_handling(request.data)
        return result.success(outcome)
161+
162+
class SyncWeb(APIView):
    # PUT endpoint that triggers a sync of one document via
    # DocumentSerializers.Sync.
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['PUT'],
        description=_('Synchronize web site types'),
        summary=_('Synchronize web site types'),
        operation_id=_('Synchronize web site types'),
        parameters=SyncWebAPI.get_parameters(),
        responses=SyncWebAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')],
    )
    @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
        # All inputs come from the route; the request body is not read.
        scope = {'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
        sync_serializer = DocumentSerializers.Sync(data=scope)
        outcome = sync_serializer.sync()
        return result.success(outcome)
179+
180+
class Refresh(APIView):
    # PUT endpoint that refreshes one document via
    # DocumentSerializers.Operate.refresh, passing the body's ``state_list``.
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['PUT'],
        summary=_('Refresh document vector library'),
        description=_('Refresh document vector library'),
        operation_id=_('Refresh document vector library'),
        parameters=RefreshAPI.get_parameters(),
        request=RefreshAPI.get_request(),
        responses=RefreshAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')],
    )
    @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
        scope = {'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
        operate_serializer = DocumentSerializers.Operate(data=scope)
        # ``state_list`` is forwarded as-is; it is None when absent from the body.
        state_list = request.data.get('state_list')
        outcome = operate_serializer.refresh(state_list)
        return result.success(outcome)
198+
199+
class CancelTask(APIView):
    # PUT endpoint that cancels one kind of task on a single document via
    # DocumentSerializers.Operate.cancel.
    authentication_classes = [TokenAuth]

    @extend_schema(
        summary=_('Cancel task'),
        description=_('Cancel task'),
        operation_id=_('Cancel task'),
        parameters=CancelTaskAPI.get_parameters(),
        request=CancelTaskAPI.get_request(),
        responses=CancelTaskAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')],
    )
    @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
        # Route values select the document; the body says which task type to cancel.
        scope = {'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
        operate_serializer = DocumentSerializers.Operate(data=scope)
        outcome = operate_serializer.cancel(request.data)
        return result.success(outcome)
216+
217+
class BatchCancelTask(APIView):
    # PUT endpoint that cancels one kind of task on several documents of a
    # knowledge base via DocumentSerializers.Batch.batch_cancel.
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['PUT'],
        summary=_('Cancel tasks in batches'),
        description=_('Cancel tasks in batches'),
        operation_id=_('Cancel tasks in batches'),
        parameters=BatchCancelTaskAPI.get_parameters(),
        request=BatchCancelTaskAPI.get_request(),
        responses=BatchCancelTaskAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')],
    )
    # Same permission gate as the sibling document-editing endpoints; without
    # it any authenticated caller could cancel tasks.
    @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str = None):
        # NOTE(review): the registered route
        # ``.../document/<str:document_id>/cancel_task/batch`` captures
        # ``document_id`` and passes it as a keyword argument, so the handler
        # must accept it or dispatch fails with a TypeError. The value is not
        # used: the affected documents come from the request body. Consider
        # dropping the segment from the route instead.
        scope = {'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
        batch_serializer = DocumentSerializers.Batch(data=scope)
        outcome = batch_serializer.batch_cancel(request.data)
        return result.success(outcome)
233+
143234
class Batch(APIView):
144235
authentication_classes = [TokenAuth]
145236

0 commit comments

Comments
 (0)