Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions apps/application/flow/i_step_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def get_answer_list(self) -> List[Answer] | None:
self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]

def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
get_node_params=lambda node: node.properties.get('node_data')):
get_node_params=lambda node: node.properties.get('node_data'), salt=None):
# 当前步骤上下文,用于存储当前步骤信息
self.status = 200
self.err_message = ''
Expand All @@ -188,7 +188,8 @@ def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
"".join([*sorted(up_node_id_list),
node.id]))),
"utf-8")).hexdigest()
"utf-8")).hexdigest() + (
"__" + str(salt) if salt is not None else '')

def valid_args(self, node_params, flow_params):
flow_params_serializer_class = self.get_flow_params_serializer_class()
Expand Down
167 changes: 167 additions & 0 deletions apps/application/flow/loop_workflow_manage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: workflow_manage.py
@date:2024/1/9 17:40
@desc:
"""
from concurrent.futures import ThreadPoolExecutor
from typing import List

from django.db import close_old_connections
from django.utils.translation import get_language
from langchain_core.prompts import PromptTemplate

from application.flow.common import Workflow
from application.flow.i_step_node import WorkFlowPostHandler, INode
from application.flow.step_node import get_node
from application.flow.workflow_manage import WorkflowManage
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.system_to_response import SystemToResponse

executor = ThreadPoolExecutor(max_workers=200)


class NodeResultFuture:
def __init__(self, r, e, status=200):
self.r = r
self.e = e
self.status = status

def result(self):
if self.status == 200:
return self.r
else:
raise self.e


def await_result(result, timeout=1):
try:
result.result(timeout)
return False
except Exception as e:
return True


class NodeChunkManage:

def __init__(self, work_flow):
self.node_chunk_list = []
self.current_node_chunk = None
self.work_flow = work_flow

def add_node_chunk(self, node_chunk):
self.node_chunk_list.append(node_chunk)

def contains(self, node_chunk):
return self.node_chunk_list.__contains__(node_chunk)

def pop(self):
if self.current_node_chunk is None:
try:
current_node_chunk = self.node_chunk_list.pop(0)
self.current_node_chunk = current_node_chunk
except IndexError as e:
pass
if self.current_node_chunk is not None:
try:
chunk = self.current_node_chunk.chunk_list.pop(0)
return chunk
except IndexError as e:
if self.current_node_chunk.is_end():
self.current_node_chunk = None
if self.work_flow.answer_is_not_empty():
chunk = self.work_flow.base_to_response.to_stream_chunk_response(
self.work_flow.params['chat_id'],
self.work_flow.params['chat_record_id'],
'\n\n', False, 0, 0)
self.work_flow.append_answer('\n\n')
return chunk
return self.pop()
return None


class LoopWorkflowManage(WorkflowManage):

def __init__(self, flow: Workflow,
params,
work_flow_post_handler: WorkFlowPostHandler,
parentWorkflowManage,
loop_params,
base_to_response: BaseToResponse = SystemToResponse(), start_node_id=None,
start_node_data=None, chat_record=None, child_node=None):
self.parentWorkflowManage = parentWorkflowManage
self.loop_params = loop_params
super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
None,
None, start_node_id, start_node_data, chat_record, child_node)

def get_node_cls_by_id(self, node_id, up_node_id_list=None,
get_node_params=lambda node: node.properties.get('node_data')):
for node in self.flow.nodes:
if node.id == node_id:
node_instance = get_node(node.type)(node,
self.params, self, up_node_id_list,
get_node_params,
salt=self.get_index())
return node_instance
return None

def stream(self):
close_old_connections()
language = get_language()
self.run_chain_async(self.start_node, None, language)
return self.await_result()

def get_index(self):
return self.loop_params.get('index')

def get_start_node(self):
start_node_list = [node for node in self.flow.nodes if
['loop-start-node'].__contains__(node.type)]
return start_node_list[0]

def get_reference_field(self, node_id: str, fields: List[str]):
"""
@param node_id: 节点id
@param fields: 字段
@return:
"""
if node_id == 'global':
return self.parentWorkflowManage.get_reference_field(node_id, fields)
elif node_id == 'chat':
return self.parentWorkflowManage.get_reference_field(node_id, fields)
else:
node = self.get_node_by_id(node_id)
if node:
return node.get_reference_field(fields)
return self.parentWorkflowManage.get_reference_field(node_id, fields)

def get_workflow_content(self):
context = {
'global': self.context,
'chat': self.chat_context
}

for node in self.node_context:
context[node.id] = node.context
return context

def reset_prompt(self, prompt: str):
prompt = super().reset_prompt(prompt)
prompt = self.parentWorkflowManage.reset_prompt(prompt)
return prompt

def generate_prompt(self, prompt: str):
"""
格式化生成提示词
@param prompt: 提示词信息
@return: 格式化后的提示词
"""

context = {**self.get_workflow_content(), **self.parentWorkflowManage.get_workflow_content()}
prompt = self.reset_prompt(prompt)
prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
value = prompt_template.format(context=context)
return value
4 changes: 3 additions & 1 deletion apps/application/flow/step_node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from .form_node import *
from .image_generate_step_node import *
from .image_understand_step_node import *
from .loop_node import *
from .loop_start_node import *
from .mcp_node import BaseMcpNode
from .question_node import *
from .reranker_node import *
Expand All @@ -30,7 +32,7 @@
BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode,
BaseDocumentExtractNode,
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode]
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode, BaseLoopNode, BaseLoopStartStepNode]


def get_node(node_type):
Expand Down
9 changes: 9 additions & 0 deletions apps/application/flow/step_node/loop_node/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: __init__.py
@date:2025/3/11 18:24
@desc:
"""
from .impl import *
56 changes: 56 additions & 0 deletions apps/application/flow/step_node/loop_node/i_loop_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: i_loop_node.py
@date:2025/3/11 18:19
@desc:
"""
from typing import Type

from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.exception.app_exception import AppApiException


class ILoopNodeSerializer(serializers.Serializer):
loop_type = serializers.CharField(required=True, label=_("loop_type"))
array = serializers.ListField(required=False, allow_null=True,
label=_("array"))
number = serializers.IntegerField(required=False, allow_null=True,
label=_("number"))
loop_body = serializers.DictField(required=True, label="循环体")

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
loop_type = self.data.get('loop_type')
if loop_type == 'ARRAY':
array = self.data.get('array')
if array is None or len(array) == 0:
message = _('{field}, this field is required.', field='array')
raise AppApiException(500, message)
elif loop_type == 'NUMBER':
number = self.data.get('number')
if number is None:
message = _('{field}, this field is required.', field='number')
raise AppApiException(500, message)


class ILoopNode(INode):
type = 'loop-node'

def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ILoopNodeSerializer

def _run(self):
array = self.node_params_serializer.data.get('array')
if self.node_params_serializer.data.get('loop_type') == 'ARRAY':
array = self.workflow_manage.get_reference_field(
array[0],
array[1:])
return self.execute(**{**self.node_params_serializer.data, "array": array}, **self.flow_params_serializer.data)

def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult:
pass
9 changes: 9 additions & 0 deletions apps/application/flow/step_node/loop_node/impl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: __init__.py.py
@date:2025/3/11 18:24
@desc:
"""
from .base_loop_node import BaseLoopNode
Loading
Loading