Skip to content

Commit 466175f

Browse files
authored
perf: workflow chat (#3247)
1 parent d9dc3db commit 466175f

File tree

6 files changed

+254
-200
lines changed

6 files changed

+254
-200
lines changed

apps/application/flow/common.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,21 @@
77
@desc:
88
"""
99

10+
from typing import List, Dict
11+
12+
from django.db.models import QuerySet
13+
from django.utils.translation import gettext as _
14+
from rest_framework.exceptions import ErrorDetail, ValidationError
15+
16+
from common.exception.app_exception import AppApiException
17+
from common.utils.common import group_by
18+
from models_provider.models import Model
19+
from models_provider.tools import get_model_credential
20+
from tools.models.tool import Tool
21+
22+
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
23+
'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node']
24+
1025

1126
class Answer:
1227
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
@@ -42,3 +57,208 @@ def end(self, chunk=None):
4257

4358
def is_end(self):
4459
return self.status == 200
60+
61+
62+
class Edge:
63+
def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
64+
self.id = _id
65+
self.type = _type
66+
self.sourceNodeId = sourceNodeId
67+
self.targetNodeId = targetNodeId
68+
for keyword in keywords:
69+
self.__setattr__(keyword, keywords.get(keyword))
70+
71+
72+
class Node:
73+
def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
74+
self.id = _id
75+
self.type = _type
76+
self.x = x
77+
self.y = y
78+
self.properties = properties
79+
for keyword in kwargs:
80+
self.__setattr__(keyword, kwargs.get(keyword))
81+
82+
83+
class EdgeNode:
84+
edge: Edge
85+
node: Node
86+
87+
def __init__(self, edge, node):
88+
self.edge = edge
89+
self.node = node
90+
91+
92+
class Workflow:
93+
"""
94+
节点列表
95+
"""
96+
nodes: List[Node]
97+
"""
98+
线列表
99+
"""
100+
edges: List[Edge]
101+
"""
102+
节点id:node
103+
"""
104+
node_map: Dict[str, Node]
105+
"""
106+
节点id:当前节点id上面的所有节点
107+
"""
108+
up_node_map: Dict[str, List[EdgeNode]]
109+
"""
110+
节点id:当前节点id下面的所有节点
111+
"""
112+
next_node_map: Dict[str, List[EdgeNode]]
113+
114+
def __init__(self, nodes: List[Node], edges: List[Edge]):
115+
self.nodes = nodes
116+
self.edges = edges
117+
self.node_map = {node.id: node for node in nodes}
118+
119+
self.up_node_map = {key: [EdgeNode(edge, self.node_map.get(edge.sourceNodeId)) for
120+
edge in edges] for
121+
key, edges in
122+
group_by(edges, key=lambda edge: edge.targetNodeId).items()}
123+
124+
self.next_node_map = {key: [EdgeNode(edge, self.node_map.get(edge.targetNodeId)) for edge in edges] for
125+
key, edges in
126+
group_by(edges, key=lambda edge: edge.sourceNodeId).items()}
127+
128+
def get_node(self, node_id):
129+
"""
130+
根据node_id 获取节点信息
131+
@param node_id: node_id
132+
@return: 节点信息
133+
"""
134+
return self.node_map.get(node_id)
135+
136+
def get_up_edge_nodes(self, node_id) -> List[EdgeNode]:
137+
"""
138+
根据节点id 获取当前连接前置节点和连线
139+
@param node_id: 节点id
140+
@return: 节点连线列表
141+
"""
142+
return self.up_node_map.get(node_id)
143+
144+
def get_next_edge_nodes(self, node_id) -> List[EdgeNode]:
145+
"""
146+
根据节点id 获取当前连接目标节点和连线
147+
@param node_id: 节点id
148+
@return: 节点连线列表
149+
"""
150+
return self.next_node_map.get(node_id)
151+
152+
def get_up_nodes(self, node_id) -> List[Node]:
153+
"""
154+
根据节点id 获取当前连接前置节点
155+
@param node_id: 节点id
156+
@return: 节点列表
157+
"""
158+
return [en.node for en in self.up_node_map.get(node_id)]
159+
160+
def get_next_nodes(self, node_id) -> List[Node]:
161+
"""
162+
根据节点id 获取当前连接目标节点
163+
@param node_id: 节点id
164+
@return: 节点列表
165+
"""
166+
return [en.node for en in self.next_node_map.get(node_id, [])]
167+
168+
@staticmethod
169+
def new_instance(flow_obj: Dict):
170+
nodes = flow_obj.get('nodes')
171+
edges = flow_obj.get('edges')
172+
nodes = [Node(node.get('id'), node.get('type'), **node)
173+
for node in nodes]
174+
edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges]
175+
return Workflow(nodes, edges)
176+
177+
def get_start_node(self):
178+
start_node_list = [node for node in self.nodes if node.id == 'start-node']
179+
return start_node_list[0]
180+
181+
def get_search_node(self):
182+
return [node for node in self.nodes if node.type == 'search-dataset-node']
183+
184+
def is_valid(self):
185+
"""
186+
校验工作流数据
187+
"""
188+
self.is_valid_model_params()
189+
self.is_valid_start_node()
190+
self.is_valid_base_node()
191+
self.is_valid_work_flow()
192+
193+
@staticmethod
194+
def is_valid_node_params(node: Node):
195+
from application.flow.step_node import get_node
196+
get_node(node.type)(node, None, None)
197+
198+
def is_valid_node(self, node: Node):
199+
self.is_valid_node_params(node)
200+
if node.type == 'condition-node':
201+
branch_list = node.properties.get('node_data').get('branch')
202+
for branch in branch_list:
203+
source_anchor_id = f"{node.id}_{branch.get('id')}_right"
204+
edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id]
205+
if len(edge_list) == 0:
206+
raise AppApiException(500,
207+
_('The branch {branch} of the {node} node needs to be connected').format(
208+
node=node.properties.get("stepName"), branch=branch.get("type")))
209+
210+
else:
211+
edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
212+
if len(edge_list) == 0 and not end_nodes.__contains__(node.type):
213+
raise AppApiException(500, _("{node} Nodes cannot be considered as end nodes").format(
214+
node=node.properties.get("stepName")))
215+
216+
def is_valid_work_flow(self, up_node=None):
217+
if up_node is None:
218+
up_node = self.get_start_node()
219+
self.is_valid_node(up_node)
220+
next_nodes = self.get_next_nodes(up_node)
221+
for next_node in next_nodes:
222+
self.is_valid_work_flow(next_node)
223+
224+
def is_valid_start_node(self):
225+
start_node_list = [node for node in self.nodes if node.id == 'start-node']
226+
if len(start_node_list) == 0:
227+
raise AppApiException(500, _('The starting node is required'))
228+
if len(start_node_list) > 1:
229+
raise AppApiException(500, _('There can only be one starting node'))
230+
231+
def is_valid_model_params(self):
232+
node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or node.type == 'question-node')]
233+
for node in node_list:
234+
model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first()
235+
if model is None:
236+
raise ValidationError(ErrorDetail(
237+
_('The node {node} model does not exist').format(node=node.properties.get("stepName"))))
238+
credential = get_model_credential(model.provider, model.model_type, model.model_name)
239+
model_params_setting = node.properties.get('node_data', {}).get('model_params_setting')
240+
model_params_setting_form = credential.get_model_params_setting_form(
241+
model.model_name)
242+
if model_params_setting is None:
243+
model_params_setting = model_params_setting_form.get_default_form_data()
244+
node.properties.get('node_data', {})['model_params_setting'] = model_params_setting
245+
if node.properties.get('status', 200) != 200:
246+
raise ValidationError(
247+
ErrorDetail(_("Node {node} is unavailable").format(node.properties.get("stepName"))))
248+
node_list = [node for node in self.nodes if (node.type == 'function-lib-node')]
249+
for node in node_list:
250+
function_lib_id = node.properties.get('node_data', {}).get('function_lib_id')
251+
if function_lib_id is None:
252+
raise ValidationError(ErrorDetail(
253+
_('The library ID of node {node} cannot be empty').format(node=node.properties.get("stepName"))))
254+
f_lib = QuerySet(Tool).filter(id=function_lib_id).first()
255+
if f_lib is None:
256+
raise ValidationError(ErrorDetail(_("The function library for node {node} is not available").format(
257+
node=node.properties.get("stepName"))))
258+
259+
def is_valid_base_node(self):
260+
base_node_list = [node for node in self.nodes if node.id == 'base-node']
261+
if len(base_node_list) == 0:
262+
raise AppApiException(500, _('Basic information node is required'))
263+
if len(base_node_list) > 1:
264+
raise AppApiException(500, _('There can only be one basic information node'))

0 commit comments

Comments
 (0)