diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index f60f208c226..e8caf162d73 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -123,6 +123,8 @@ def get_loop_workflow_node(node_list): def get_workflow_state(workflow): + if workflow.is_the_task_interrupted(): + return State.REVOKED details = workflow.get_runtime_details() node_list = details.values() all_node = [*node_list, *get_loop_workflow_node(node_list)] diff --git a/apps/application/flow/knowledge_workflow_manage.py b/apps/application/flow/knowledge_workflow_manage.py index ab739932b74..1a696675cbc 100644 --- a/apps/application/flow/knowledge_workflow_manage.py +++ b/apps/application/flow/knowledge_workflow_manage.py @@ -30,10 +30,10 @@ def __init__(self, flow: Workflow, work_flow_post_handler: WorkFlowPostHandler, base_to_response: BaseToResponse = SystemToResponse(), start_node_id=None, - start_node_data=None, chat_record=None, child_node=None): + start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False): super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None, None, - None, None, start_node_id, start_node_data, chat_record, child_node) + None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted) def get_params_serializer_class(self): return KnowledgeFlowParamsSerializer @@ -91,6 +91,9 @@ def hand_node_result(self, current_node, node_result_future): list(result) if current_node.status == 500: return None + if self.is_the_task_interrupted(): + current_node.status = 201 + return None return current_result except Exception as e: traceback.print_exc() diff --git a/apps/application/flow/loop_workflow_manage.py b/apps/application/flow/loop_workflow_manage.py index bf01e7606db..27c84f4dc66 100644 --- a/apps/application/flow/loop_workflow_manage.py +++ b/apps/application/flow/loop_workflow_manage.py @@ -92,14 +92,14 @@ def __init__(self, flow: Workflow, get_loop_context, base_to_response: BaseToResponse = SystemToResponse(), start_node_id=None, - start_node_data=None, chat_record=None, child_node=None): + start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False): self.parentWorkflowManage = parentWorkflowManage self.loop_params = loop_params self.get_loop_context = get_loop_context self.loop_field_list = [] super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None, None, - None, None, start_node_id, start_node_data, chat_record, child_node) + None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted) def get_node_cls_by_id(self, node_id, up_node_id_list=None, get_node_params=lambda node: node.properties.get('node_data')): diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 85d046b2e59..6e0b32df631 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -63,6 +63,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo response_reasoning_content = False for chunk in response: + if workflow.is_the_task_interrupted(): + break reasoning_chunk = reasoning.get_reasoning_content(chunk) content_chunk = reasoning_chunk.get('content') if 'reasoning_content' in chunk.additional_kwargs: @@ -110,7 +112,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor if 'reasoning_content' in meta: reasoning_content = (meta.get('reasoning_content', '') or '') else: - reasoning_content = (reasoning_result.get('reasoning_content') or '') + (reasoning_result_end.get('reasoning_content') or '') + reasoning_content = (reasoning_result.get('reasoning_content') or '') + ( + reasoning_result_end.get('reasoning_content') or '') _write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content) diff --git a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py index 2b17851a84d..cd75272e264 100644 --- a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py +++ b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py @@ -268,7 +268,8 @@ def workflow_manage_new_instance(loop_data, global_data, start_node_id=None, start_node_id=start_node_id, start_node_data=start_node_data, chat_record=chat_record, - child_node=child_node + child_node=child_node, + is_the_task_interrupted=self.workflow_manage.is_the_task_interrupted ) return workflow_manage diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 7f23391ff5e..839794d54c4 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -97,7 +97,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH video_list=None, other_list=None, start_node_id=None, - start_node_data=None, chat_record=None, child_node=None): + start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False): if form_data is None: form_data = {} if image_list is None: @@ -138,6 +138,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH self.global_field_list = [] self.chat_field_list = [] self.init_fields() + self.is_the_task_interrupted = is_the_task_interrupted if start_node_id is not None: self.load_node(chat_record, start_node_id, start_node_data) else: diff --git a/apps/common/constants/cache_version.py b/apps/common/constants/cache_version.py index 6664acb5623..1b202dc8e07 100644 --- a/apps/common/constants/cache_version.py +++ b/apps/common/constants/cache_version.py @@ -26,7 +26,7 @@ class Cache_Version(Enum): SYSTEM = "SYSTEM", lambda key: key # 应用对接三方应用的缓存 APPLICATION_THIRD_PARTY = "APPLICATION:THIRD_PARTY", lambda key: key - + KNOWLEDGE_WORKFLOW_INTERRUPTED = "KNOWLEDGE_WORKFLOW_INTERRUPTED", lambda action_id: action_id # 对话 CHAT = "CHAT", lambda key: key diff --git a/apps/knowledge/serializers/knowledge_workflow.py b/apps/knowledge/serializers/knowledge_workflow.py index 1138bfcf402..245869b085f 100644 --- a/apps/knowledge/serializers/knowledge_workflow.py +++ b/apps/knowledge/serializers/knowledge_workflow.py @@ -15,6 +15,7 @@ from application.flow.knowledge_workflow_manage import KnowledgeWorkflowManage from application.flow.step_node import get_node from application.serializers.application import get_mcp_tools +from common.constants.cache_version import Cache_Version from common.db.search import page_search from common.exception.app_exception import AppApiException from common.utils.rsa_util import rsa_long_decrypt @@ -22,7 +23,7 @@ from knowledge.models import KnowledgeScope, Knowledge, KnowledgeType, KnowledgeWorkflow, KnowledgeWorkflowVersion from knowledge.models.knowledge_action import KnowledgeAction, State from knowledge.serializers.knowledge import KnowledgeModelSerializer -from maxkb.const import CONFIG +from django.core.cache import cache from system_manage.models import AuthTargetType from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer from tools.models import Tool @@ -52,7 +53,11 @@ class KnowledgeWorkflowActionSerializer(serializers.Serializer): knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) def get_query_set(self, instance: Dict): - query_set = QuerySet(KnowledgeAction).filter(knowledge_id=self.data.get('knowledge_id')).values('id','knowledge_id',"state",'meta','run_time',"create_time") + query_set = QuerySet(KnowledgeAction).filter(knowledge_id=self.data.get('knowledge_id')).values('id', + 'knowledge_id', + "state", 'meta', + 'run_time', + "create_time") if instance.get("user_name"): query_set = query_set.filter(meta__user_name__icontains=instance.get('user_name')) if instance.get('state'): @@ -73,7 +78,8 @@ def page(self, current_page, page_size, instance: Dict, is_valid=True): KnowledgeWorkflowActionListQuerySerializer(data=instance).is_valid(raise_exception=True) return page_search(current_page, page_size, self.get_query_set(instance), lambda a: {'id': a.get("id"), 'knowledge_id': a.get("knowledge_id"), 'state': a.get("state"), - 'meta': a.get("meta"), 'run_time': a.get("run_time"), 'create_time': a.get("create_time")}) + 'meta': a.get("meta"), 'run_time': a.get("run_time"), + 'create_time': a.get("create_time")}) def action(self, instance: Dict, user, with_valid=True): if with_valid: @@ -91,7 +97,10 @@ def action(self, instance: Dict, user, with_valid=True): {'knowledge_id': self.data.get("knowledge_id"), 'knowledge_action_id': knowledge_action_id, 'stream': True, 'workspace_id': self.data.get("workspace_id"), **instance}, - KnowledgeWorkflowPostHandler(None, knowledge_action_id)) + KnowledgeWorkflowPostHandler(None, knowledge_action_id), + is_the_task_interrupted=lambda: cache.get( + Cache_Version.KNOWLEDGE_WORKFLOW_INTERRUPTED.get_key(action_id=knowledge_action_id), + version=Cache_Version.KNOWLEDGE_WORKFLOW_INTERRUPTED.get_version()) or False) work_flow_manage.run() return {'id': knowledge_action_id, 'knowledge_id': self.data.get("knowledge_id"), 'state': State.STARTED, 'details': {}, 'meta': meta} @@ -135,6 +144,15 @@ def one(self, is_valid=True): 'details': knowledge_action.details, 'meta': knowledge_action.meta} + def cancel(self, is_valid=True): + if is_valid: + self.is_valid(raise_exception=True) + knowledge_action_id = self.data.get("id") + cache.set(Cache_Version.KNOWLEDGE_WORKFLOW_INTERRUPTED.get_key(action_id=knowledge_action_id), True, + version=Cache_Version.KNOWLEDGE_WORKFLOW_INTERRUPTED.get_version()) + QuerySet(KnowledgeAction).filter(id=knowledge_action_id).update(state=State.REVOKE) + return True + class KnowledgeWorkflowSerializer(serializers.Serializer): class Datasource(serializers.Serializer): diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py index 3e27fbda6b5..9f3007306bb 100644 --- a/apps/knowledge/urls.py +++ b/apps/knowledge/urls.py @@ -76,6 +76,7 @@ path('workspace//knowledge//action//', views.KnowledgeWorkflowActionView.Page.as_view()), path('workspace//knowledge//upload_document', views.KnowledgeWorkflowUploadDocumentView.as_view()), path('workspace//knowledge//action/', views.KnowledgeWorkflowActionView.Operate.as_view()), + path('workspace//knowledge//action//cancel', views.KnowledgeWorkflowActionView.Cancel.as_view()), path('workspace//knowledge//mcp_tools', views.McpServers.as_view()), path('workspace//knowledge//knowledge_version', views.KnowledgeWorkflowVersionView.as_view()), path('workspace//knowledge//knowledge_version//', views.KnowledgeWorkflowVersionView.Page.as_view()), diff --git a/apps/knowledge/views/knowledge_workflow.py b/apps/knowledge/views/knowledge_workflow.py index 074218e9d4d..2eb0175fe28 100644 --- a/apps/knowledge/views/knowledge_workflow.py +++ b/apps/knowledge/views/knowledge_workflow.py @@ -168,6 +168,33 @@ def get(self, request, workspace_id: str, knowledge_id: str, knowledge_action_id data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id, 'id': knowledge_action_id}) .one()) + class Cancel(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['POST'], + description=_('Cancel knowledge workflow action'), + summary=_('Cancel knowledge workflow action'), + operation_id=_('Cancel knowledge workflow action'), # type: ignore + parameters=KnowledgeWorkflowActionApi.get_parameters(), + responses=DefaultResultSerializer(), + tags=[_('Knowledge Base')] # type: ignore + ) + @has_permissions( + PermissionConstants.KNOWLEDGE_WORKFLOW_EDIT.get_workspace_knowledge_permission(), + PermissionConstants.KNOWLEDGE_WORKFLOW_EDIT.get_workspace_permission_workspace_manage_role(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + ViewPermission( + [RoleConstants.USER.get_workspace_role()], + [PermissionConstants.KNOWLEDGE.get_workspace_knowledge_permission()], + CompareConstants.AND + ), + ) + def post(self, request, workspace_id: str, knowledge_id: str, knowledge_action_id: str): + return result.success(KnowledgeWorkflowActionSerializer.Operate( + data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id, 'id': knowledge_action_id}) + .cancel()) + class KnowledgeWorkflowView(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/knowledge/knowledge.ts b/ui/src/api/knowledge/knowledge.ts index e3ed013e530..06f151526cc 100644 --- a/ui/src/api/knowledge/knowledge.ts +++ b/ui/src/api/knowledge/knowledge.ts @@ -433,7 +433,13 @@ const getWorkflowAction: ( ) => Promise> = (knowledge_id: string, knowledge_action_id, loading) => { return get(`${prefix.value}/${knowledge_id}/action/${knowledge_action_id}`, {}, loading) } - +const cancelWorkflowAction: ( + knowledge_id: string, + knowledge_action_id: string, + loading?: Ref, +) => Promise> = (knowledge_id: string, knowledge_action_id, loading) => { + return post(`${prefix.value}/${knowledge_id}/action/${knowledge_action_id}/cancel`, {}, loading) +} /** * mcp 节点 */ @@ -480,4 +486,5 @@ export default { putKnowledgeWorkflow, workflowUpload, getWorkflowActionPage, + cancelWorkflowAction, } diff --git a/ui/src/views/knowledge-workflow/component/execution-record/ExecutionDetailDrawer.vue b/ui/src/views/knowledge-workflow/component/execution-record/ExecutionDetailDrawer.vue index 9b2e9e36091..ee8b7995781 100644 --- a/ui/src/views/knowledge-workflow/component/execution-record/ExecutionDetailDrawer.vue +++ b/ui/src/views/knowledge-workflow/component/execution-record/ExecutionDetailDrawer.vue @@ -47,6 +47,20 @@ {{ $t('common.status.fail') }} + + + {{ $t('common.status.REVOKED', '已取消') }} + + + + {{ $t('views.document.fileStatus.REVOKE', '取消中') }} + {{ $t('common.status.padding') }} diff --git a/ui/src/views/knowledge-workflow/component/execution-record/ExecutionRecordDrawer.vue b/ui/src/views/knowledge-workflow/component/execution-record/ExecutionRecordDrawer.vue index d4685240ef8..e236fcdbf86 100644 --- a/ui/src/views/knowledge-workflow/component/execution-record/ExecutionRecordDrawer.vue +++ b/ui/src/views/knowledge-workflow/component/execution-record/ExecutionRecordDrawer.vue @@ -64,6 +64,14 @@ {{ $t('common.status.fail') }} + + + {{ $t('common.status.REVOKED', '已取消') }} + + + + {{ $t('views.document.fileStatus.REVOKE', '取消中') }} + {{ $t('common.status.padding') }} @@ -87,11 +95,22 @@ @@ -157,6 +176,11 @@ const toDetails = (row: any) => { ExecutionDetailDrawerRef.value?.open() } +const cancel = (row: any) => { + loadSharedApi({ type: 'knowledge', systemType: apiType.value }) + .cancelWorkflowAction(active_knowledge_id.value, row.id, loading) + .then((ok: any) => {}) +} const changeFilterHandle = () => { query.value = { user_name: '', status: '' } }