Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare.compare import Compare


class ContainCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class EqualCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class GECompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class GTCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare import Compare
from application.flow.compare import Compare


class IsNotNullCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare import Compare
from application.flow.compare import Compare


class IsNotTrueCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare import Compare
from application.flow.compare import Compare


class IsNullCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare import Compare
from application.flow.compare import Compare


class IsTrueCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class LECompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class LenEqualCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class LenGECompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class LenGTCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class LenLECompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class LenLTCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class LTCompare(Compare):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
from typing import List

from application.flow.step_node.condition_node.compare.compare import Compare
from application.flow.compare import Compare


class NotContainCompare(Compare):
Expand Down
5 changes: 3 additions & 2 deletions apps/application/flow/i_step_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def get_answer_list(self) -> List[Answer] | None:
self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]

def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
get_node_params=lambda node: node.properties.get('node_data')):
get_node_params=lambda node: node.properties.get('node_data'), salt=None):
# 当前步骤上下文,用于存储当前步骤信息
self.status = 200
self.err_message = ''
Expand All @@ -188,7 +188,8 @@ def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
"".join([*sorted(up_node_id_list),
node.id]))),
"utf-8")).hexdigest()
"utf-8")).hexdigest() + (
"__" + str(salt) if salt is not None else '')

def valid_args(self, node_params, flow_params):
flow_params_serializer_class = self.get_flow_params_serializer_class()
Expand Down
193 changes: 193 additions & 0 deletions apps/application/flow/loop_workflow_manage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: workflow_manage.py
@date:2024/1/9 17:40
@desc:
"""
from concurrent.futures import ThreadPoolExecutor
from typing import List

from django.db import close_old_connections
from django.utils.translation import get_language
from langchain_core.prompts import PromptTemplate

from application.flow.common import Workflow
from application.flow.i_step_node import WorkFlowPostHandler, INode
from application.flow.step_node import get_node
from application.flow.workflow_manage import WorkflowManage
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.system_to_response import SystemToResponse

executor = ThreadPoolExecutor(max_workers=200)


class NodeResultFuture:
def __init__(self, r, e, status=200):
self.r = r
self.e = e
self.status = status

def result(self):
if self.status == 200:
return self.r
else:
raise self.e


def await_result(result, timeout=1):
try:
result.result(timeout)
return False
except Exception as e:
return True


class NodeChunkManage:

def __init__(self, work_flow):
self.node_chunk_list = []
self.current_node_chunk = None
self.work_flow = work_flow

def add_node_chunk(self, node_chunk):
self.node_chunk_list.append(node_chunk)

def contains(self, node_chunk):
return self.node_chunk_list.__contains__(node_chunk)

def pop(self):
if self.current_node_chunk is None:
try:
current_node_chunk = self.node_chunk_list.pop(0)
self.current_node_chunk = current_node_chunk
except IndexError as e:
pass
if self.current_node_chunk is not None:
try:
chunk = self.current_node_chunk.chunk_list.pop(0)
return chunk
except IndexError as e:
if self.current_node_chunk.is_end():
self.current_node_chunk = None
if self.work_flow.answer_is_not_empty():
chunk = self.work_flow.base_to_response.to_stream_chunk_response(
self.work_flow.params['chat_id'],
self.work_flow.params['chat_record_id'],
'\n\n', False, 0, 0)
self.work_flow.append_answer('\n\n')
return chunk
return self.pop()
return None


class LoopWorkflowManage(WorkflowManage):

def __init__(self, flow: Workflow,
params,
work_flow_post_handler: WorkFlowPostHandler,
parentWorkflowManage,
loop_params,
get_loop_context,
base_to_response: BaseToResponse = SystemToResponse(),
start_node_id=None,
start_node_data=None, chat_record=None, child_node=None):
self.parentWorkflowManage = parentWorkflowManage
self.loop_params = loop_params
self.get_loop_context = get_loop_context
self.loop_field_list = []
super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
None,
None, start_node_id, start_node_data, chat_record, child_node)

def get_node_cls_by_id(self, node_id, up_node_id_list=None,
get_node_params=lambda node: node.properties.get('node_data')):
for node in self.flow.nodes:
if node.id == node_id:
node_instance = get_node(node.type)(node,
self.params, self, up_node_id_list,
get_node_params,
salt=self.get_index())
return node_instance
return None

def stream(self):
close_old_connections()
language = get_language()
self.run_chain_async(self.start_node, None, language)
return self.await_result()

def get_index(self):
return self.loop_params.get('index')

def get_start_node(self):
start_node_list = [node for node in self.flow.nodes if
['loop-start-node'].__contains__(node.type)]
return start_node_list[0]

def get_reference_field(self, node_id: str, fields: List[str]):
"""
@param node_id: 节点id
@param fields: 字段
@return:
"""
if node_id == 'global':
return self.parentWorkflowManage.get_reference_field(node_id, fields)
elif node_id == 'chat':
return self.parentWorkflowManage.get_reference_field(node_id, fields)
elif node_id == 'loop':
loop_context = self.get_loop_context()
return INode.get_field(loop_context, fields)
else:
node = self.get_node_by_id(node_id)
if node:
return node.get_reference_field(fields)
return self.parentWorkflowManage.get_reference_field(node_id, fields)

def get_workflow_content(self):
context = {
'global': self.context,
'chat': self.chat_context,
'loop': self.get_loop_context(),
}

for node in self.node_context:
context[node.id] = node.context
return context

def init_fields(self):
super().init_fields()
loop_field_list = []
loop_start_node = self.flow.get_node('loop-start-node')
loop_input_field_list = loop_start_node.properties.get('loop_input_field_list')
node_name = loop_start_node.properties.get('stepName')
node_id = loop_start_node.id
if loop_input_field_list is not None:
for f in loop_input_field_list:
loop_field_list.append(
{'label': f.get('label'), 'value': f.get('field'), 'node_id': node_id, 'node_name': node_name})
self.loop_field_list = loop_field_list

def reset_prompt(self, prompt: str):
prompt = super().reset_prompt(prompt)
for field in self.loop_field_list:
chatLabel = f"loop.{field.get('value')}"
chatValue = f"context.get('loop').get('{field.get('value', '')}','')"
prompt = prompt.replace(chatLabel, chatValue)

prompt = self.parentWorkflowManage.reset_prompt(prompt)
return prompt

def generate_prompt(self, prompt: str):
"""
格式化生成提示词
@param prompt: 提示词信息
@return: 格式化后的提示词
"""

context = {**self.get_workflow_content(), **self.parentWorkflowManage.get_workflow_content()}
prompt = self.reset_prompt(prompt)
prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
value = prompt_template.format(context=context)
return value
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review:

  • The NodeResultFuture class has unnecessary logic for handling execution status and exceptions separately. It can be simplified to directly handle both using Python's exception mechanism.
class NodeResultFuture:
    def __init__(self, result, error):
        self.result = result
        self.error = error

    def unwrap(self):
        if isinstance(self.error, ExecutionError):  # Assuming ExecutionError is a custom exception type
            raise self.error
        return self.result
  • In await_result, consider adding logging statements to track when requests are still open after the specified timeout.
import logging

logger = logging.getLogger(__name__)

def await_result(result, timeout=1):
    try:
        result.result(timeout)
        logger.info(f"Request completed within {timeout} seconds.")
        return False
    except Exception as e:
        logger.warning(f"Request did not complete within {timeout} seconds. Error: {e}")
        return True
  • For managing asynchronous tasks in a more straightforward manner, use asyncio.
import asyncio

executor = ThreadPoolExecutor(max_workers=200)

async def run_task(func, *args, **kwargs):
    with executor.submit(func, *args, **kwargs) as future:
        loop = asyncio.get_running_loop()
        done, pending = await asyncio.wait({future}, timeout=1)
        
        if done:
            yield future.result().unwrap()
        elif pending:
            print("Timeout occurred")
  • Consider cleaning up unused imports at the top of your file.
from datetime import timedelta
import json
from typing import *
from urllib.parse import urlencode, urlparse
import os

Overall, this code provides basic structure for an AI-based workflow management system with threading support. Adding asynchronous capabilities could improve performance significantly, especially when dealing with time-consuming operations or long-running workflows.

10 changes: 8 additions & 2 deletions apps/application/flow/step_node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
from .image_generate_step_node import *
from .image_to_video_step_node import BaseImageToVideoNode
from .image_understand_step_node import *
from .intent_node import *
from .loop_break_node import BaseLoopBreakNode
from .loop_continue_node import BaseLoopContinueNode
from .loop_node import *
from .loop_start_node import *
from .mcp_node import BaseMcpNode
from .question_node import *
from .reranker_node import *
Expand All @@ -26,15 +31,16 @@
from .tool_lib_node import *
from .tool_node import *
from .variable_assign_node import BaseVariableAssignNode
from .intent_node import *

node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseQuestionNode,
BaseConditionNode, BaseReplyNode,
BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode,
BaseDocumentExtractNode,
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode, BaseTextToVideoNode, BaseImageToVideoNode,
BaseIntentNode]
BaseIntentNode, BaseLoopNode, BaseLoopStartStepNode,
BaseLoopContinueNode,
BaseLoopBreakNode]


def get_node(node_type):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import List

from application.flow.i_step_node import NodeResult
from application.flow.step_node.condition_node.compare import compare_handle_list
from application.flow.compare import compare_handle_list
from application.flow.step_node.condition_node.i_condition_node import IConditionNode


Expand Down
Loading
Loading