Skip to content

Commit dd99133

Browse files
committed
fix: 修复子应用表单调用无法调用问题
1 parent 433ae5d commit dd99133

File tree

18 files changed

+273
-93
lines changed

18 files changed

+273
-93
lines changed

apps/application/flow/i_step_node.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
4040
node.context['run_time'] = time.time() - node.context['start_time']
4141

4242

43+
def is_interrupt(node, step_variable: Dict, global_variable: Dict):
44+
return node.type == 'form-node' and not node.context.get('is_submit', False)
45+
46+
4347
class WorkFlowPostHandler:
4448
def __init__(self, chat_info, client_id, client_type):
4549
self.chat_info = chat_info
@@ -57,7 +61,7 @@ def handler(self, chat_id,
5761
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
5862
'answer_tokens' in row and row.get('answer_tokens') is not None])
5963
answer_text_list = workflow.get_answer_text_list()
60-
answer_text = '\n\n'.join(answer_text_list)
64+
answer_text = '\n\n'.join(answer['content'] for answer in answer_text_list)
6165
if workflow.chat_record is not None:
6266
chat_record = workflow.chat_record
6367
chat_record.answer_text = answer_text
@@ -91,17 +95,26 @@ def handler(self, chat_id,
9195

9296
class NodeResult:
9397
def __init__(self, node_variable: Dict, workflow_variable: Dict,
94-
_write_context=write_context):
98+
_write_context=write_context, _is_interrupt=is_interrupt):
9599
self._write_context = _write_context
96100
self.node_variable = node_variable
97101
self.workflow_variable = workflow_variable
102+
self._is_interrupt = _is_interrupt
98103

99104
def write_context(self, node, workflow):
100105
return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
101106

102107
def is_assertion_result(self):
103108
return 'branch_id' in self.node_variable
104109

110+
def is_interrupt_exec(self, current_node):
111+
"""
112+
是否中断执行
113+
@param current_node:
114+
@return:
115+
"""
116+
return self._is_interrupt(current_node, self.node_variable, self.workflow_variable)
117+
105118

106119
class ReferenceAddressSerializer(serializers.Serializer):
107120
node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
@@ -139,14 +152,18 @@ def save_context(self, details, workflow_manage):
139152
pass
140153

141154
def get_answer_text(self):
142-
return self.answer_text
155+
if self.answer_text is None:
156+
return None
157+
return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id,
158+
'chat_record_id': self.workflow_params['chat_record_id']}
143159

144-
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None):
160+
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
161+
get_node_params=lambda node: node.properties.get('node_data')):
145162
# 当前步骤上下文,用于存储当前步骤信息
146163
self.status = 200
147164
self.err_message = ''
148165
self.node = node
149-
self.node_params = node.properties.get('node_data')
166+
self.node_params = get_node_params(node)
150167
self.workflow_params = workflow_params
151168
self.workflow_manage = workflow_manage
152169
self.node_params_serializer = None

apps/application/flow/step_node/application_node/i_application_node.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class ApplicationNodeSerializer(serializers.Serializer):
1414
user_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.uuid("用户输入字段"))
1515
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
1616
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
17+
child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点"))
18+
node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据"))
1719

1820

1921
class IApplicationNode(INode):
@@ -55,5 +57,5 @@ def _run(self):
5557
message=str(question), **kwargs)
5658

5759
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
58-
app_document_list=None, app_image_list=None, **kwargs) -> NodeResult:
60+
app_document_list=None, app_image_list=None, child_node=None, node_data=None, **kwargs) -> NodeResult:
5961
pass

apps/application/flow/step_node/application_node/impl/base_application_node.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,25 @@
22
import json
33
import time
44
import uuid
5-
from typing import List, Dict
5+
from typing import Dict
6+
67
from application.flow.i_step_node import NodeResult, INode
78
from application.flow.step_node.application_node.i_application_node import IApplicationNode
89
from application.models import Chat
9-
from common.handle.impl.response.openai_to_response import OpenaiToResponse
1010

1111

1212
def string_to_uuid(input_str):
1313
return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str))
1414

1515

16+
def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict):
17+
return node_variable.get('is_interrupt_exec', False)
18+
19+
1620
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
1721
result = node_variable.get('result')
22+
node.context['child_node'] = node_variable['child_node']
23+
node.context['is_interrupt_exec'] = node_variable['is_interrupt_exec']
1824
node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0)
1925
node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
2026
node.context['answer'] = answer
@@ -36,17 +42,34 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
3642
response = node_variable.get('result')
3743
answer = ''
3844
usage = {}
45+
node_child_node = {}
46+
is_interrupt_exec = False
3947
for chunk in response:
4048
# 先把流转成字符串
4149
response_content = chunk.decode('utf-8')[6:]
4250
response_content = json.loads(response_content)
43-
choices = response_content.get('choices')
44-
if choices and isinstance(choices, list) and len(choices) > 0:
45-
content = choices[0].get('delta', {}).get('content', '')
46-
answer += content
47-
yield content
51+
content = response_content.get('content', '')
52+
runtime_node_id = response_content.get('runtime_node_id', '')
53+
chat_record_id = response_content.get('chat_record_id', '')
54+
child_node = response_content.get('child_node')
55+
node_type = response_content.get('node_type')
56+
real_node_id = response_content.get('real_node_id')
57+
node_is_end = response_content.get('node_is_end', False)
58+
if node_type == 'form-node':
59+
is_interrupt_exec = True
60+
answer += content
61+
node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
62+
'child_node': child_node}
63+
yield {'content': content,
64+
'node_type': node_type,
65+
'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
66+
'child_node': child_node,
67+
'real_node_id': real_node_id,
68+
'node_is_end': node_is_end}
4869
usage = response_content.get('usage', {})
4970
node_variable['result'] = {'usage': usage}
71+
node_variable['is_interrupt_exec'] = is_interrupt_exec
72+
node_variable['child_node'] = node_child_node
5073
_write_context(node_variable, workflow_variable, node, workflow, answer)
5174

5275

@@ -64,6 +87,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
6487

6588

6689
class BaseApplicationNode(IApplicationNode):
90+
def get_answer_text(self):
91+
if self.answer_text is None:
92+
return None
93+
return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id,
94+
'chat_record_id': self.workflow_params['chat_record_id'], 'child_node': self.context.get('child_node')}
6795

6896
def save_context(self, details, workflow_manage):
6997
self.context['answer'] = details.get('answer')
@@ -72,7 +100,7 @@ def save_context(self, details, workflow_manage):
72100
self.answer_text = details.get('answer')
73101

74102
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
75-
app_document_list=None, app_image_list=None,
103+
app_document_list=None, app_image_list=None, child_node=None, node_data=None,
76104
**kwargs) -> NodeResult:
77105
from application.serializers.chat_message_serializers import ChatMessageSerializer
78106
# 生成嵌入应用的chat_id
@@ -85,6 +113,14 @@ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_c
85113
app_document_list = []
86114
if app_image_list is None:
87115
app_image_list = []
116+
runtime_node_id = None
117+
record_id = None
118+
child_node_value = None
119+
if child_node is not None:
120+
runtime_node_id = child_node.get('runtime_node_id')
121+
record_id = child_node.get('chat_record_id')
122+
child_node_value = child_node.get('child_node')
123+
88124
response = ChatMessageSerializer(
89125
data={'chat_id': current_chat_id, 'message': message,
90126
're_chat': re_chat,
@@ -94,16 +130,20 @@ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_c
94130
'client_type': client_type,
95131
'document_list': app_document_list,
96132
'image_list': app_image_list,
97-
'form_data': kwargs}).chat(base_to_response=OpenaiToResponse())
133+
'runtime_node_id': runtime_node_id,
134+
'chat_record_id': record_id,
135+
'child_node': child_node_value,
136+
'node_data': node_data,
137+
'form_data': kwargs}).chat()
98138
if response.status_code == 200:
99139
if stream:
100140
content_generator = response.streaming_content
101141
return NodeResult({'result': content_generator, 'question': message}, {},
102-
_write_context=write_context_stream)
142+
_write_context=write_context_stream, _is_interrupt=_is_interrupt_exec)
103143
else:
104144
data = json.loads(response.content)
105145
return NodeResult({'result': data, 'question': message}, {},
106-
_write_context=write_context)
146+
_write_context=write_context, _is_interrupt=_is_interrupt_exec)
107147

108148
def get_details(self, index: int, **kwargs):
109149
global_fields = []

apps/application/flow/step_node/form_node/i_form_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
class FormNodeParamsSerializer(serializers.Serializer):
1818
form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list("表单配置"))
1919
form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char('表单输出内容'))
20+
form_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据"))
2021

2122

2223
class IFormNode(INode):
@@ -29,5 +30,5 @@ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
2930
def _run(self):
3031
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
3132

32-
def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
33+
def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
3334
pass

apps/application/flow/step_node/form_node/impl/base_form_node.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,12 @@ def save_context(self, details, workflow_manage):
4242
for key in form_data:
4343
self.context[key] = form_data[key]
4444

45-
def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
45+
def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
46+
if form_data is not None:
47+
self.context['is_submit'] = True
48+
self.context['form_data'] = form_data
49+
else:
50+
self.context['is_submit'] = False
4651
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
4752
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
4853
"is_submit": self.context.get("is_submit", False)}
@@ -63,7 +68,8 @@ def get_answer_text(self):
6368
form = f'<form_rander>{json.dumps(form_setting)}</form_rander>'
6469
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
6570
value = prompt_template.format(form=form)
66-
return value
71+
return {'content': value, 'runtime_node_id': self.runtime_node_id,
72+
'chat_record_id': self.workflow_params['chat_record_id']}
6773

6874
def get_details(self, index: int, **kwargs):
6975
form_content_format = self.context.get('form_content_format')

0 commit comments

Comments
 (0)