Skip to content

Commit 458d28a

Browse files
committed
feat: 支持表单收集功能 30%
1 parent 44b3aed commit 458d28a

File tree

30 files changed

+732
-45
lines changed

30 files changed

+732
-45
lines changed

apps/application/flow/i_step_node.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
@desc:
88
"""
99
import time
10+
import uuid
1011
from abc import abstractmethod
1112
from typing import Type, Dict, List
1213

@@ -118,7 +119,12 @@ class FlowParamsSerializer(serializers.Serializer):
118119

119120

120121
class INode:
121-
def __init__(self, node, workflow_params, workflow_manage):
122+
123+
@abstractmethod
124+
def save_context(self, details, workflow_manage):
125+
pass
126+
127+
def __init__(self, node, workflow_params, workflow_manage, runtime_node_id=None):
122128
# 当前步骤上下文,用于存储当前步骤信息
123129
self.status = 200
124130
self.err_message = ''
@@ -130,6 +136,10 @@ def __init__(self, node, workflow_params, workflow_manage):
130136
self.flow_params_serializer = None
131137
self.context = {}
132138
self.id = node.id
139+
if runtime_node_id is None:
140+
self.runtime_node_id = str(uuid.uuid1())
141+
else:
142+
self.runtime_node_id = runtime_node_id
133143

134144
def valid_args(self, node_params, flow_params):
135145
flow_params_serializer_class = self.get_flow_params_serializer_class()

apps/application/flow/step_node/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,17 @@
88
"""
99
from .ai_chat_step_node import *
1010
from .condition_node import *
11-
from .question_node import *
12-
from .search_dataset_node import *
13-
from .start_node import *
1411
from .direct_reply_node import *
12+
from .form_node import *
1513
from .function_lib_node import *
1614
from .function_node import *
15+
from .question_node import *
1716
from .reranker_node import *
17+
from .search_dataset_node import *
18+
from .start_node import *
1819

1920
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
20-
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode]
21+
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseFormNode]
2122

2223

2324
def get_node(node_type):

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def get_default_model_params_setting(model_id):
7373

7474

7575
class BaseChatNode(IChatNode):
76+
def save_context(self, details, workflow_manage):
77+
self.context['answer'] = details.get('answer')
78+
self.context['question'] = details.get('question')
79+
7680
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
7781
model_params_setting=None,
7882
**kwargs) -> NodeResult:

apps/application/flow/step_node/condition_node/impl/base_condition_node.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515

1616
class BaseConditionNode(IConditionNode):
17+
def save_context(self, details, workflow_manage):
18+
self.context['branch_id'] = details.get('branch_id')
19+
self.context['branch_name'] = details.get('branch_name')
20+
1721
def execute(self, **kwargs) -> NodeResult:
1822
branch_list = self.node_params_serializer.data['branch']
1923
branch = self._execute(branch_list)

apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414

1515
class BaseReplyNode(IReplyNode):
16+
def save_context(self, details, workflow_manage):
17+
self.context['answer'] = details.get('answer')
18+
1619
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
1720
if reply_type == 'referencing':
1821
result = self.get_reference_content(fields)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: __init__.py.py
6+
@date:2024/11/4 14:48
7+
@desc:
8+
"""
9+
from .impl import *
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: i_form_node.py
6+
@date:2024/11/4 14:48
7+
@desc:
8+
"""
9+
from typing import Type
10+
11+
from rest_framework import serializers
12+
13+
from application.flow.i_step_node import INode, NodeResult
14+
from common.util.field_message import ErrMessage
15+
16+
17+
class FormNodeParamsSerializer(serializers.Serializer):
18+
form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list("表单配置"))
19+
form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char('表单输出内容'))
20+
21+
22+
class IFormNode(INode):
23+
type = 'form-node'
24+
25+
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
26+
return FormNodeParamsSerializer
27+
28+
def _run(self):
29+
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
30+
31+
def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
32+
pass
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: __init__.py.py
6+
@date:2024/11/4 14:49
7+
@desc:
8+
"""
9+
from .base_form_node import BaseFormNode
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: base_form_node.py
6+
@date:2024/11/4 14:52
7+
@desc:
8+
"""
9+
import json
10+
import time
11+
from typing import Dict
12+
13+
from langchain_core.prompts import PromptTemplate
14+
15+
from application.flow.i_step_node import NodeResult
16+
from application.flow.step_node.form_node.i_form_node import IFormNode
17+
18+
19+
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
20+
if step_variable is not None:
21+
for key in step_variable:
22+
node.context[key] = step_variable[key]
23+
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
24+
result = step_variable['result']
25+
yield result
26+
workflow.answer += result
27+
node.context['run_time'] = time.time() - node.context['start_time']
28+
29+
30+
class BaseFormNode(IFormNode):
31+
def save_context(self, details, workflow_manage):
32+
self.context['result'] = details.get('result')
33+
self.context['form_content_format'] = details.get('form_content_format')
34+
self.context['form_field_list'] = details.get('form_field_list')
35+
self.context['run_time'] = details.get('run_time')
36+
self.context['start_time'] = details.get('start_time')
37+
38+
def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
39+
form = f'<form_rander>{json.dumps({"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id, "chat_record_id": self.flow_params_serializer.data.get("chat_record_id")})}</form_rander>'
40+
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
41+
value = prompt_template.format(form=form)
42+
return NodeResult(
43+
{'result': value, 'form_field_list': form_field_list, 'form_content_format': 'form_content_format'}, {},
44+
_write_context=write_context)
45+
46+
def get_details(self, index: int, **kwargs):
47+
form_content_format = self.context.get('form_content_format')
48+
form_field_list = self.context.get('form_field_list')
49+
form_data = self.context.get('form_data')
50+
form = f'<form_rander>{json.dumps({"form_field_list": form_field_list, "node_id": self.node.id, "form_data": form_data})}</form_rander>'
51+
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
52+
value = prompt_template.format(form=form)
53+
return {
54+
'name': self.node.properties.get('stepName'),
55+
"index": index,
56+
"result": value,
57+
"form_content_format": self.context.get('form_content_format'),
58+
"form_field_list": self.context.get('form_field_list'),
59+
'form_data': self.context.get('form_data'),
60+
'start_time': self.context.get('start_time'),
61+
'run_time': self.context.get('run_time'),
62+
'type': self.node.type,
63+
'status': self.status,
64+
'err_message': self.err_message
65+
}

apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def convert_value(name: str, value, _type, is_required, source, node):
8989

9090

9191
class BaseFunctionLibNodeNode(IFunctionLibNode):
92+
def save_context(self, details, workflow_manage):
93+
self.context['result'] = details.get('result')
94+
9295
def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
9396
function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first()
9497
if not function_lib.is_active:

0 commit comments

Comments
 (0)