Skip to content

Commit 5837650

Browse files
committed
feat: loopNode
1 parent e11c550 commit 5837650

File tree

24 files changed

+1006
-28
lines changed

24 files changed

+1006
-28
lines changed

apps/application/flow/step_node/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .function_node import *
1616
from .question_node import *
1717
from .reranker_node import *
18-
18+
from .loop_node import *
1919
from .document_extract_node import *
2020
from .image_understand_step_node import *
2121
from .image_generate_step_node import *
@@ -31,7 +31,7 @@
3131
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
3232
BaseDocumentExtractNode,
3333
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
34-
BaseImageGenerateNode, BaseVariableAssignNode]
34+
BaseImageGenerateNode, BaseVariableAssignNode, BaseLoopNode]
3535

3636

3737
def get_node(node_type):
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: __init__.py
6+
@date:2025/3/11 18:24
7+
@desc:
8+
"""
9+
from .impl import *
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: i_loop_node.py
6+
@date:2025/3/11 18:19
7+
@desc:
8+
"""
9+
from typing import Type
10+
11+
from application.flow.i_step_node import INode, NodeResult
12+
from rest_framework import serializers
13+
14+
from common.exception.app_exception import AppApiException
15+
from common.util.field_message import ErrMessage
16+
from django.utils.translation import gettext_lazy as _
17+
18+
19+
class ILoopNodeSerializer(serializers.Serializer):
20+
loop_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("loop_type")))
21+
array = serializers.ListField(required=False, allow_null=True,
22+
error_messages=ErrMessage.char(_("array")))
23+
number = serializers.IntegerField(required=False, allow_null=True,
24+
error_messages=ErrMessage.char(_("number")))
25+
loop_body = serializers.DictField(required=True, error_messages=ErrMessage.char("循环体"))
26+
27+
def is_valid(self, *, raise_exception=False):
28+
super().is_valid(raise_exception=True)
29+
loop_type = self.data.get('loop_type')
30+
if loop_type == 'ARRAY':
31+
array = self.data.get('array')
32+
if array is None or len(array) == 0:
33+
message = _('{field}, this field is required.', field='array')
34+
raise AppApiException(500, message)
35+
elif loop_type == 'NUMBER':
36+
number = self.data.get('number')
37+
if number is None:
38+
message = _('{field}, this field is required.', field='number')
39+
raise AppApiException(500, message)
40+
41+
42+
class ILoopNode(INode):
43+
type = 'loop-node'
44+
45+
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
46+
return ILoopNodeSerializer
47+
48+
def _run(self):
49+
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
50+
51+
def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult:
52+
pass
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: __init__.py.py
6+
@date:2025/3/11 18:24
7+
@desc:
8+
"""
9+
from .base_loop_node import BaseLoopNode
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: base_loop_node.py
6+
@date:2025/3/11 18:24
7+
@desc:
8+
"""
9+
import time
10+
from typing import Dict
11+
12+
from application.flow.i_step_node import NodeResult, WorkFlowPostHandler, INode
13+
from application.flow.step_node.loop_node.i_loop_node import ILoopNode
14+
from application.flow.tools import Reasoning
15+
from common.handle.impl.response.loop_to_response import LoopToResponse
16+
17+
18+
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
19+
reasoning_content: str):
20+
node.context['answer'] = answer
21+
node.context['run_time'] = time.time() - node.context['start_time']
22+
node.context['reasoning_content'] = reasoning_content
23+
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
24+
node.answer_text = answer
25+
26+
27+
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
28+
"""
29+
写入上下文数据 (流式)
30+
@param node_variable: 节点数据
31+
@param workflow_variable: 全局数据
32+
@param node: 节点
33+
@param workflow: 工作流管理器
34+
"""
35+
response = node_variable.get('result')
36+
answer = ''
37+
reasoning_content = ''
38+
for chunk in response:
39+
content_chunk = chunk.get('content', '')
40+
reasoning_content_chunk = chunk.get('reasoning_content', '')
41+
reasoning_content += reasoning_content_chunk
42+
answer += content_chunk
43+
yield {'content': content_chunk,
44+
'reasoning_content': reasoning_content_chunk}
45+
46+
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
47+
48+
49+
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
50+
"""
51+
写入上下文数据
52+
@param node_variable: 节点数据
53+
@param workflow_variable: 全局数据
54+
@param node: 节点实例对象
55+
@param workflow: 工作流管理器
56+
"""
57+
response = node_variable.get('result')
58+
model_setting = node.context.get('model_setting',
59+
{'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
60+
'reasoning_content_start': '<think>'})
61+
reasoning = Reasoning(model_setting.get('reasoning_content_start'), model_setting.get('reasoning_content_end'))
62+
reasoning_result = reasoning.get_reasoning_content(response)
63+
reasoning_result_end = reasoning.get_end_reasoning_content()
64+
content = reasoning_result.get('content') + reasoning_result_end.get('content')
65+
if 'reasoning_content' in response.response_metadata:
66+
reasoning_content = response.response_metadata.get('reasoning_content', '')
67+
else:
68+
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content')
69+
_write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
70+
71+
72+
def loop_number(number, loop_body):
73+
"""
74+
指定次数循环
75+
@return:
76+
"""
77+
pass
78+
79+
80+
def loop_array(array, loop_body):
81+
"""
82+
循环数组
83+
@return:
84+
"""
85+
pass
86+
87+
88+
def loop_loop(loop_body):
89+
"""
90+
无线循环
91+
@return:
92+
"""
93+
pass
94+
95+
96+
class LoopWorkFlowPostHandler(WorkFlowPostHandler):
97+
def handler(self, chat_id,
98+
chat_record_id,
99+
answer,
100+
workflow):
101+
pass
102+
103+
104+
class BaseLoopNode(ILoopNode):
105+
def save_context(self, details, workflow_manage):
106+
self.context['result'] = details.get('result')
107+
self.answer_text = str(details.get('result'))
108+
109+
def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult:
110+
from application.flow.workflow_manage import WorkflowManage, Flow
111+
workflow_manage = WorkflowManage(Flow.new_instance(loop_body), self.workflow_manage.params,
112+
LoopWorkFlowPostHandler(self.workflow_manage.work_flow_post_handler.chat_info
113+
,
114+
self.workflow_manage.work_flow_post_handler.client_id,
115+
self.workflow_manage.work_flow_post_handler.client_type)
116+
, base_to_response=LoopToResponse())
117+
result = workflow_manage.stream()
118+
return NodeResult({"result": result}, {}, _write_context=write_context_stream)
119+
120+
def get_details(self, index: int, **kwargs):
121+
return {
122+
'name': self.node.properties.get('stepName'),
123+
"index": index,
124+
"result": self.context.get('result'),
125+
"params": self.context.get('params'),
126+
'run_time': self.context.get('run_time'),
127+
'type': self.node.type,
128+
'status': self.status,
129+
'err_message': self.err_message
130+
}

apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def execute(self, dataset_id_list, dataset_setting, question,
8888
'is_hit_handling_method_list': [row for row in result if row.get('is_hit_handling_method')],
8989
'data': '\n'.join(
9090
[f"{reset_title(paragraph.get('title', ''))}{paragraph.get('content')}" for paragraph in
91-
paragraph_list])[0:dataset_setting.get('max_paragraph_char_number', 5000)],
91+
result])[0:dataset_setting.get('max_paragraph_char_number', 5000)],
9292
'directly_return': '\n'.join(
9393
[paragraph.get('content') for paragraph in
9494
result if

apps/application/flow/step_node/start_node/impl/base_start_node.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ def get_global_variable(node):
3333
class BaseStartStepNode(IStarNode):
3434
def save_context(self, details, workflow_manage):
3535
base_node = self.workflow_manage.get_base_node()
36-
default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', []))
36+
default_global_variable = {}
37+
if base_node is not None:
38+
default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', []))
39+
3740
workflow_variable = {**default_global_variable, **get_global_variable(self)}
3841
self.context['question'] = details.get('question')
3942
self.context['run_time'] = details.get('run_time')
@@ -50,7 +53,9 @@ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
5053

5154
def execute(self, question, **kwargs) -> NodeResult:
5255
base_node = self.workflow_manage.get_base_node()
53-
default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', []))
56+
default_global_variable = {}
57+
if base_node is not None:
58+
default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', []))
5459
workflow_variable = {**default_global_variable, **get_global_variable(self)}
5560
"""
5661
开始节点 初始化全局变量

apps/application/flow/workflow_manage.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,12 @@ def get_node_params(n):
338338
node.node_chunk.end()
339339
self.node_context.append(node)
340340

341+
def stream(self):
342+
close_old_connections()
343+
language = get_language()
344+
self.run_chain_async(self.start_node, None, language)
345+
return self.await_result()
346+
341347
def run(self):
342348
close_old_connections()
343349
language = get_language()
@@ -801,6 +807,8 @@ def get_base_node(self):
801807
@return:
802808
"""
803809
base_node_list = [node for node in self.flow.nodes if node.type == 'base-node']
810+
if len(base_node_list) == 0:
811+
return None
804812
return base_node_list[0]
805813

806814
def get_node_cls_by_id(self, node_id, up_node_id_list=None,
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: LoopToResponse.py
6+
@date:2025/3/12 17:21
7+
@desc:
8+
"""
9+
import json
10+
11+
from common.handle.impl.response.system_to_response import SystemToResponse
12+
13+
14+
class LoopToResponse(SystemToResponse):
15+
16+
def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end,
17+
completion_tokens,
18+
prompt_tokens, other_params: dict = None):
19+
if other_params is None:
20+
other_params = {}
21+
return {'chat_id': str(chat_id), 'chat_record_id': str(chat_record_id), 'operate': True,
22+
'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list,
23+
'is_end': is_end,
24+
'usage': {'completion_tokens': completion_tokens,
25+
'prompt_tokens': prompt_tokens,
26+
'total_tokens': completion_tokens + prompt_tokens},
27+
**other_params}

ui/src/enums/workflow.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,7 @@ export enum WorkflowType {
1616
FormNode = 'form-node',
1717
TextToSpeechNode = 'text-to-speech-node',
1818
SpeechToTextNode = 'speech-to-text-node',
19-
ImageGenerateNode = 'image-generate-node'
19+
ImageGenerateNode = 'image-generate-node',
20+
LoopNode = 'loop-node',
21+
LoopBodyNode = 'loop-body-node'
2022
}

0 commit comments

Comments
 (0)