Skip to content

Commit 4f1f0fe

Browse files
authored
feat: Knowledge base workflow supports terminating execution (#4536)
1 parent 2233f04 commit 4f1f0fe

File tree

13 files changed

+119
-18
lines changed

13 files changed

+119
-18
lines changed

apps/application/flow/i_step_node.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def get_loop_workflow_node(node_list):
123123

124124

125125
def get_workflow_state(workflow):
126+
if workflow.is_the_task_interrupted():
127+
return State.REVOKED
126128
details = workflow.get_runtime_details()
127129
node_list = details.values()
128130
all_node = [*node_list, *get_loop_workflow_node(node_list)]

apps/application/flow/knowledge_workflow_manage.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def __init__(self, flow: Workflow,
3030
work_flow_post_handler: WorkFlowPostHandler,
3131
base_to_response: BaseToResponse = SystemToResponse(),
3232
start_node_id=None,
33-
start_node_data=None, chat_record=None, child_node=None):
33+
start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
3434
super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
3535
None,
36-
None, None, start_node_id, start_node_data, chat_record, child_node)
36+
None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted)
3737

3838
def get_params_serializer_class(self):
3939
return KnowledgeFlowParamsSerializer
@@ -91,6 +91,9 @@ def hand_node_result(self, current_node, node_result_future):
9191
list(result)
9292
if current_node.status == 500:
9393
return None
94+
if self.is_the_task_interrupted():
95+
current_node.status = 201
96+
return None
9497
return current_result
9598
except Exception as e:
9699
traceback.print_exc()

apps/application/flow/loop_workflow_manage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ def __init__(self, flow: Workflow,
9292
get_loop_context,
9393
base_to_response: BaseToResponse = SystemToResponse(),
9494
start_node_id=None,
95-
start_node_data=None, chat_record=None, child_node=None):
95+
start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
9696
self.parentWorkflowManage = parentWorkflowManage
9797
self.loop_params = loop_params
9898
self.get_loop_context = get_loop_context
9999
self.loop_field_list = []
100100
super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
101101
None,
102-
None, None, start_node_id, start_node_data, chat_record, child_node)
102+
None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted)
103103

104104
def get_node_cls_by_id(self, node_id, up_node_id_list=None,
105105
get_node_params=lambda node: node.properties.get('node_data')):

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
6363
response_reasoning_content = False
6464

6565
for chunk in response:
66+
if workflow.is_the_task_interrupted():
67+
break
6668
reasoning_chunk = reasoning.get_reasoning_content(chunk)
6769
content_chunk = reasoning_chunk.get('content')
6870
if 'reasoning_content' in chunk.additional_kwargs:
@@ -110,7 +112,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
110112
if 'reasoning_content' in meta:
111113
reasoning_content = (meta.get('reasoning_content', '') or '')
112114
else:
113-
reasoning_content = (reasoning_result.get('reasoning_content') or '') + (reasoning_result_end.get('reasoning_content') or '')
115+
reasoning_content = (reasoning_result.get('reasoning_content') or '') + (
116+
reasoning_result_end.get('reasoning_content') or '')
114117
_write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
115118

116119

apps/application/flow/step_node/loop_node/impl/base_loop_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ def workflow_manage_new_instance(loop_data, global_data, start_node_id=None,
268268
start_node_id=start_node_id,
269269
start_node_data=start_node_data,
270270
chat_record=chat_record,
271-
child_node=child_node
271+
child_node=child_node,
272+
is_the_task_interrupted=self.workflow_manage.is_the_task_interrupted
272273
)
273274

274275
return workflow_manage

apps/application/flow/workflow_manage.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
9797
video_list=None,
9898
other_list=None,
9999
start_node_id=None,
100-
start_node_data=None, chat_record=None, child_node=None):
100+
start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
101101
if form_data is None:
102102
form_data = {}
103103
if image_list is None:
@@ -138,6 +138,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
138138
self.global_field_list = []
139139
self.chat_field_list = []
140140
self.init_fields()
141+
self.is_the_task_interrupted = is_the_task_interrupted
141142
if start_node_id is not None:
142143
self.load_node(chat_record, start_node_id, start_node_data)
143144
else:

apps/common/constants/cache_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Cache_Version(Enum):
2626
SYSTEM = "SYSTEM", lambda key: key
2727
# 应用对接三方应用的缓存
2828
APPLICATION_THIRD_PARTY = "APPLICATION:THIRD_PARTY", lambda key: key
29-
29+
KNOWLEDGE_WORKFLOW_INTERRUPTED = "KNOWLEDGE_WORKFLOW_INTERRUPTED", lambda action_id: action_id
3030
# 对话
3131
CHAT = "CHAT", lambda key: key
3232

apps/knowledge/serializers/knowledge_workflow.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
from application.flow.knowledge_workflow_manage import KnowledgeWorkflowManage
1616
from application.flow.step_node import get_node
1717
from application.serializers.application import get_mcp_tools
18+
from common.constants.cache_version import Cache_Version
1819
from common.db.search import page_search
1920
from common.exception.app_exception import AppApiException
2021
from common.utils.rsa_util import rsa_long_decrypt
2122
from common.utils.tool_code import ToolExecutor
2223
from knowledge.models import KnowledgeScope, Knowledge, KnowledgeType, KnowledgeWorkflow, KnowledgeWorkflowVersion
2324
from knowledge.models.knowledge_action import KnowledgeAction, State
2425
from knowledge.serializers.knowledge import KnowledgeModelSerializer
25-
from maxkb.const import CONFIG
26+
from django.core.cache import cache
2627
from system_manage.models import AuthTargetType
2728
from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer
2829
from tools.models import Tool
@@ -52,7 +53,11 @@ class KnowledgeWorkflowActionSerializer(serializers.Serializer):
5253
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
5354

5455
def get_query_set(self, instance: Dict):
55-
query_set = QuerySet(KnowledgeAction).filter(knowledge_id=self.data.get('knowledge_id')).values('id','knowledge_id',"state",'meta','run_time',"create_time")
56+
query_set = QuerySet(KnowledgeAction).filter(knowledge_id=self.data.get('knowledge_id')).values('id',
57+
'knowledge_id',
58+
"state", 'meta',
59+
'run_time',
60+
"create_time")
5661
if instance.get("user_name"):
5762
query_set = query_set.filter(meta__user_name__icontains=instance.get('user_name'))
5863
if instance.get('state'):
@@ -73,7 +78,8 @@ def page(self, current_page, page_size, instance: Dict, is_valid=True):
7378
KnowledgeWorkflowActionListQuerySerializer(data=instance).is_valid(raise_exception=True)
7479
return page_search(current_page, page_size, self.get_query_set(instance),
7580
lambda a: {'id': a.get("id"), 'knowledge_id': a.get("knowledge_id"), 'state': a.get("state"),
76-
'meta': a.get("meta"), 'run_time': a.get("run_time"), 'create_time': a.get("create_time")})
81+
'meta': a.get("meta"), 'run_time': a.get("run_time"),
82+
'create_time': a.get("create_time")})
7783

7884
def action(self, instance: Dict, user, with_valid=True):
7985
if with_valid:
@@ -91,7 +97,10 @@ def action(self, instance: Dict, user, with_valid=True):
9197
{'knowledge_id': self.data.get("knowledge_id"), 'knowledge_action_id': knowledge_action_id, 'stream': True,
9298
'workspace_id': self.data.get("workspace_id"),
9399
**instance},
94-
KnowledgeWorkflowPostHandler(None, knowledge_action_id))
100+
KnowledgeWorkflowPostHandler(None, knowledge_action_id),
101+
is_the_task_interrupted=lambda: cache.get(
102+
Cache_Version.KNOWLEDGE_WORKFLOW_INTERRUPTED.get_key(action_id=knowledge_action_id),
103+
version=Cache_Version.KNOWLEDGE_WORKFLOW_INTERRUPTED.get_version()) or False)
95104
work_flow_manage.run()
96105
return {'id': knowledge_action_id, 'knowledge_id': self.data.get("knowledge_id"), 'state': State.STARTED,
97106
'details': {}, 'meta': meta}
@@ -135,6 +144,15 @@ def one(self, is_valid=True):
135144
'details': knowledge_action.details,
136145
'meta': knowledge_action.meta}
137146

147+
def cancel(self, is_valid=True):
148+
if is_valid:
149+
self.is_valid(raise_exception=True)
150+
knowledge_action_id = self.data.get("id")
151+
cache.set(Cache_Version.KNOWLEDGE_WORKFLOW_INTERRUPTED.get_key(action_id=knowledge_action_id), True,
152+
version=Cache_Version.KNOWLEDGE_WORKFLOW_INTERRUPTED.get_version())
153+
QuerySet(KnowledgeAction).filter(id=knowledge_action_id).update(state=State.REVOKE)
154+
return True
155+
138156

139157
class KnowledgeWorkflowSerializer(serializers.Serializer):
140158
class Datasource(serializers.Serializer):

apps/knowledge/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/action/<int:current_page>/<int:page_size>', views.KnowledgeWorkflowActionView.Page.as_view()),
7777
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/upload_document', views.KnowledgeWorkflowUploadDocumentView.as_view()),
7878
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/action/<str:knowledge_action_id>', views.KnowledgeWorkflowActionView.Operate.as_view()),
79+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/action/<str:knowledge_action_id>/cancel', views.KnowledgeWorkflowActionView.Cancel.as_view()),
7980
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/mcp_tools', views.McpServers.as_view()),
8081
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/knowledge_version', views.KnowledgeWorkflowVersionView.as_view()),
8182
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/knowledge_version/<int:current_page>/<int:page_size>', views.KnowledgeWorkflowVersionView.Page.as_view()),

apps/knowledge/views/knowledge_workflow.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,33 @@ def get(self, request, workspace_id: str, knowledge_id: str, knowledge_action_id
168168
data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id, 'id': knowledge_action_id})
169169
.one())
170170

171+
class Cancel(APIView):
172+
authentication_classes = [TokenAuth]
173+
174+
@extend_schema(
175+
methods=['POST'],
176+
description=_('Cancel knowledge workflow action'),
177+
summary=_('Cancel knowledge workflow action'),
178+
operation_id=_('Cancel knowledge workflow action'), # type: ignore
179+
parameters=KnowledgeWorkflowActionApi.get_parameters(),
180+
responses=DefaultResultSerializer(),
181+
tags=[_('Knowledge Base')] # type: ignore
182+
)
183+
@has_permissions(
184+
PermissionConstants.KNOWLEDGE_WORKFLOW_EDIT.get_workspace_knowledge_permission(),
185+
PermissionConstants.KNOWLEDGE_WORKFLOW_EDIT.get_workspace_permission_workspace_manage_role(),
186+
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(),
187+
ViewPermission(
188+
[RoleConstants.USER.get_workspace_role()],
189+
[PermissionConstants.KNOWLEDGE.get_workspace_knowledge_permission()],
190+
CompareConstants.AND
191+
),
192+
)
193+
def post(self, request, workspace_id: str, knowledge_id: str, knowledge_action_id: str):
194+
return result.success(KnowledgeWorkflowActionSerializer.Operate(
195+
data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id, 'id': knowledge_action_id})
196+
.cancel())
197+
171198

172199
class KnowledgeWorkflowView(APIView):
173200
authentication_classes = [TokenAuth]

0 commit comments

Comments
 (0)