Skip to content

Commit 232dae1

Browse files
committed
feat: loop node Unfinished
1 parent 4cb3912 commit 232dae1

File tree

32 files changed

+1563
-84
lines changed

32 files changed

+1563
-84
lines changed

apps/application/flow/i_step_node.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def get_answer_list(self) -> List[Answer] | None:
168168
self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]
169169

170170
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
171-
get_node_params=lambda node: node.properties.get('node_data')):
171+
get_node_params=lambda node: node.properties.get('node_data'), salt=None):
172172
# 当前步骤上下文,用于存储当前步骤信息
173173
self.status = 200
174174
self.err_message = ''
@@ -188,7 +188,8 @@ def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
188188
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
189189
"".join([*sorted(up_node_id_list),
190190
node.id]))),
191-
"utf-8")).hexdigest()
191+
"utf-8")).hexdigest() + (
192+
"__" + str(salt) if salt is not None else '')
192193

193194
def valid_args(self, node_params, flow_params):
194195
flow_params_serializer_class = self.get_flow_params_serializer_class()
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: workflow_manage.py
6+
@date:2024/1/9 17:40
7+
@desc:
8+
"""
9+
from concurrent.futures import ThreadPoolExecutor
10+
from typing import List
11+
12+
from django.db import close_old_connections
13+
from django.utils.translation import get_language
14+
from langchain_core.prompts import PromptTemplate
15+
16+
from application.flow.common import Workflow
17+
from application.flow.i_step_node import WorkFlowPostHandler, INode
18+
from application.flow.step_node import get_node
19+
from application.flow.workflow_manage import WorkflowManage
20+
from common.handle.base_to_response import BaseToResponse
21+
from common.handle.impl.response.system_to_response import SystemToResponse
22+
23+
executor = ThreadPoolExecutor(max_workers=200)
24+
25+
26+
class NodeResultFuture:
27+
def __init__(self, r, e, status=200):
28+
self.r = r
29+
self.e = e
30+
self.status = status
31+
32+
def result(self):
33+
if self.status == 200:
34+
return self.r
35+
else:
36+
raise self.e
37+
38+
39+
def await_result(result, timeout=1):
40+
try:
41+
result.result(timeout)
42+
return False
43+
except Exception as e:
44+
return True
45+
46+
47+
class NodeChunkManage:
48+
49+
def __init__(self, work_flow):
50+
self.node_chunk_list = []
51+
self.current_node_chunk = None
52+
self.work_flow = work_flow
53+
54+
def add_node_chunk(self, node_chunk):
55+
self.node_chunk_list.append(node_chunk)
56+
57+
def contains(self, node_chunk):
58+
return self.node_chunk_list.__contains__(node_chunk)
59+
60+
def pop(self):
61+
if self.current_node_chunk is None:
62+
try:
63+
current_node_chunk = self.node_chunk_list.pop(0)
64+
self.current_node_chunk = current_node_chunk
65+
except IndexError as e:
66+
pass
67+
if self.current_node_chunk is not None:
68+
try:
69+
chunk = self.current_node_chunk.chunk_list.pop(0)
70+
return chunk
71+
except IndexError as e:
72+
if self.current_node_chunk.is_end():
73+
self.current_node_chunk = None
74+
if self.work_flow.answer_is_not_empty():
75+
chunk = self.work_flow.base_to_response.to_stream_chunk_response(
76+
self.work_flow.params['chat_id'],
77+
self.work_flow.params['chat_record_id'],
78+
'\n\n', False, 0, 0)
79+
self.work_flow.append_answer('\n\n')
80+
return chunk
81+
return self.pop()
82+
return None
83+
84+
85+
class LoopWorkflowManage(WorkflowManage):
86+
87+
def __init__(self, flow: Workflow,
88+
params,
89+
work_flow_post_handler: WorkFlowPostHandler,
90+
parentWorkflowManage,
91+
loop_params,
92+
base_to_response: BaseToResponse = SystemToResponse(), start_node_id=None,
93+
start_node_data=None, chat_record=None, child_node=None):
94+
self.parentWorkflowManage = parentWorkflowManage
95+
self.loop_params = loop_params
96+
super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
97+
None,
98+
None, start_node_id, start_node_data, chat_record, child_node)
99+
100+
def get_node_cls_by_id(self, node_id, up_node_id_list=None,
101+
get_node_params=lambda node: node.properties.get('node_data')):
102+
for node in self.flow.nodes:
103+
if node.id == node_id:
104+
node_instance = get_node(node.type)(node,
105+
self.params, self, up_node_id_list,
106+
get_node_params,
107+
salt=self.get_index())
108+
return node_instance
109+
return None
110+
111+
def stream(self):
112+
close_old_connections()
113+
language = get_language()
114+
self.run_chain_async(self.start_node, None, language)
115+
return self.await_result()
116+
117+
def get_index(self):
118+
return self.loop_params.get('index')
119+
120+
def get_start_node(self):
121+
start_node_list = [node for node in self.flow.nodes if
122+
['loop-start-node'].__contains__(node.type)]
123+
return start_node_list[0]
124+
125+
def get_reference_field(self, node_id: str, fields: List[str]):
126+
"""
127+
@param node_id: 节点id
128+
@param fields: 字段
129+
@return:
130+
"""
131+
if node_id == 'global':
132+
return self.parentWorkflowManage.get_reference_field(node_id, fields)
133+
elif node_id == 'chat':
134+
return self.parentWorkflowManage.get_reference_field(node_id, fields)
135+
else:
136+
node = self.get_node_by_id(node_id)
137+
if node:
138+
return node.get_reference_field(fields)
139+
return self.parentWorkflowManage.get_reference_field(node_id, fields)
140+
141+
def get_workflow_content(self):
142+
context = {
143+
'global': self.context,
144+
'chat': self.chat_context
145+
}
146+
147+
for node in self.node_context:
148+
context[node.id] = node.context
149+
return context
150+
151+
def reset_prompt(self, prompt: str):
152+
prompt = super().reset_prompt(prompt)
153+
prompt = self.parentWorkflowManage.reset_prompt(prompt)
154+
return prompt
155+
156+
def generate_prompt(self, prompt: str):
157+
"""
158+
格式化生成提示词
159+
@param prompt: 提示词信息
160+
@return: 格式化后的提示词
161+
"""
162+
163+
context = {**self.get_workflow_content(), **self.parentWorkflowManage.get_workflow_content()}
164+
prompt = self.reset_prompt(prompt)
165+
prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
166+
value = prompt_template.format(context=context)
167+
return value

apps/application/flow/step_node/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from .form_node import *
1515
from .image_generate_step_node import *
1616
from .image_understand_step_node import *
17+
from .loop_node import *
18+
from .loop_start_node import *
1719
from .mcp_node import BaseMcpNode
1820
from .question_node import *
1921
from .reranker_node import *
@@ -30,7 +32,7 @@
3032
BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode,
3133
BaseDocumentExtractNode,
3234
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
33-
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode]
35+
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode, BaseLoopNode, BaseLoopStartStepNode]
3436

3537

3638
def get_node(node_type):
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: __init__.py
6+
@date:2025/3/11 18:24
7+
@desc:
8+
"""
9+
from .impl import *
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: i_loop_node.py
6+
@date:2025/3/11 18:19
7+
@desc:
8+
"""
9+
from typing import Type
10+
11+
from django.utils.translation import gettext_lazy as _
12+
from rest_framework import serializers
13+
14+
from application.flow.i_step_node import INode, NodeResult
15+
from common.exception.app_exception import AppApiException
16+
17+
18+
class ILoopNodeSerializer(serializers.Serializer):
19+
loop_type = serializers.CharField(required=True, label=_("loop_type"))
20+
array = serializers.ListField(required=False, allow_null=True,
21+
label=_("array"))
22+
number = serializers.IntegerField(required=False, allow_null=True,
23+
label=_("number"))
24+
loop_body = serializers.DictField(required=True, label="循环体")
25+
26+
def is_valid(self, *, raise_exception=False):
27+
super().is_valid(raise_exception=True)
28+
loop_type = self.data.get('loop_type')
29+
if loop_type == 'ARRAY':
30+
array = self.data.get('array')
31+
if array is None or len(array) == 0:
32+
message = _('{field}, this field is required.', field='array')
33+
raise AppApiException(500, message)
34+
elif loop_type == 'NUMBER':
35+
number = self.data.get('number')
36+
if number is None:
37+
message = _('{field}, this field is required.', field='number')
38+
raise AppApiException(500, message)
39+
40+
41+
class ILoopNode(INode):
42+
type = 'loop-node'
43+
44+
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
45+
return ILoopNodeSerializer
46+
47+
def _run(self):
48+
array = self.node_params_serializer.data.get('array')
49+
if self.node_params_serializer.data.get('loop_type') == 'ARRAY':
50+
array = self.workflow_manage.get_reference_field(
51+
array[0],
52+
array[1:])
53+
return self.execute(**{**self.node_params_serializer.data, "array": array}, **self.flow_params_serializer.data)
54+
55+
def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult:
56+
pass
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: __init__.py.py
6+
@date:2025/3/11 18:24
7+
@desc:
8+
"""
9+
from .base_loop_node import BaseLoopNode

0 commit comments

Comments
 (0)