Skip to content

Commit 83a1ffb

Browse files
authored
feat: Support session variables (#3792)
1 parent 0e78245 commit 83a1ffb

File tree

15 files changed

+404
-48
lines changed

15 files changed

+404
-48
lines changed

apps/application/flow/step_node/start_node/impl/base_start_node.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def get_default_global_variable(input_field_list: List):
2323
if item.get('default_value', None) is not None
2424
}
2525

26+
2627
def get_global_variable(node):
2728
body = node.workflow_manage.get_body()
2829
history_chat_record = node.flow_params_serializer.data.get('history_chat_record', [])
@@ -74,6 +75,7 @@ def execute(self, question, **kwargs) -> NodeResult:
7475
'other': self.workflow_manage.other_list,
7576

7677
}
78+
self.workflow_manage.chat_context = self.workflow_manage.get_chat_info().get_chat_variable()
7779
return NodeResult(node_variable, workflow_variable)
7880

7981
def get_details(self, index: int, **kwargs):

apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,68 @@
22
import json
33
from typing import List
44

5+
from django.db.models import QuerySet
6+
57
from application.flow.i_step_node import NodeResult
68
from application.flow.step_node.variable_assign_node.i_variable_assign_node import IVariableAssignNode
9+
from application.models import Chat
710

811

912
class BaseVariableAssignNode(IVariableAssignNode):
1013
def save_context(self, details, workflow_manage):
1114
self.context['variable_list'] = details.get('variable_list')
1215
self.context['result_list'] = details.get('result_list')
1316

17+
def global_evaluation(self, variable, value):
18+
self.workflow_manage.context[variable['fields'][1]] = value
19+
20+
def chat_evaluation(self, variable, value):
21+
self.workflow_manage.chat_context[variable['fields'][1]] = value
22+
23+
def handle(self, variable, evaluation):
24+
result = {
25+
'name': variable['name'],
26+
'input_value': self.get_reference_content(variable['fields']),
27+
}
28+
if variable['source'] == 'custom':
29+
if variable['type'] == 'json':
30+
if isinstance(variable['value'], dict) or isinstance(variable['value'], list):
31+
val = variable['value']
32+
else:
33+
val = json.loads(variable['value'])
34+
evaluation(variable, val)
35+
result['output_value'] = variable['value'] = val
36+
elif variable['type'] == 'string':
37+
# 变量解析 例如:{{global.xxx}}
38+
val = self.workflow_manage.generate_prompt(variable['value'])
39+
evaluation(variable, val)
40+
result['output_value'] = val
41+
else:
42+
val = variable['value']
43+
evaluation(variable, val)
44+
result['output_value'] = val
45+
else:
46+
reference = self.get_reference_content(variable['reference'])
47+
evaluation(variable, reference)
48+
result['output_value'] = reference
49+
return result
50+
1451
def execute(self, variable_list, stream, **kwargs) -> NodeResult:
1552
#
1653
result_list = []
54+
is_chat = False
1755
for variable in variable_list:
1856
if 'fields' not in variable:
1957
continue
2058
if 'global' == variable['fields'][0]:
21-
result = {
22-
'name': variable['name'],
23-
'input_value': self.get_reference_content(variable['fields']),
24-
}
25-
if variable['source'] == 'custom':
26-
if variable['type'] == 'json':
27-
if isinstance(variable['value'], dict) or isinstance(variable['value'], list):
28-
val = variable['value']
29-
else:
30-
val = json.loads(variable['value'])
31-
self.workflow_manage.context[variable['fields'][1]] = val
32-
result['output_value'] = variable['value'] = val
33-
elif variable['type'] == 'string':
34-
# 变量解析 例如:{{global.xxx}}
35-
val = self.workflow_manage.generate_prompt(variable['value'])
36-
self.workflow_manage.context[variable['fields'][1]] = val
37-
result['output_value'] = val
38-
else:
39-
val = variable['value']
40-
self.workflow_manage.context[variable['fields'][1]] = val
41-
result['output_value'] = val
42-
else:
43-
reference = self.get_reference_content(variable['reference'])
44-
self.workflow_manage.context[variable['fields'][1]] = reference
45-
result['output_value'] = reference
59+
result = self.handle(variable, self.global_evaluation)
4660
result_list.append(result)
47-
61+
if 'chat' == variable['fields'][0]:
62+
result = self.handle(variable, self.chat_evaluation)
63+
result_list.append(result)
64+
is_chat = True
65+
if is_chat:
66+
self.workflow_manage.get_chat_info().set_chat_variable(self.workflow_manage.chat_context)
4867
return NodeResult({'variable_list': variable_list, 'result_list': result_list}, {})
4968

5069
def get_reference_content(self, fields: List[str]):

apps/application/flow/workflow_manage.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
117117
self.params = params
118118
self.flow = flow
119119
self.context = {}
120+
self.chat_context = {}
120121
self.node_chunk_manage = NodeChunkManage(self)
121122
self.work_flow_post_handler = work_flow_post_handler
122123
self.current_node = None
@@ -131,6 +132,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
131132
self.lock = threading.Lock()
132133
self.field_list = []
133134
self.global_field_list = []
135+
self.chat_field_list = []
134136
self.init_fields()
135137
if start_node_id is not None:
136138
self.load_node(chat_record, start_node_id, start_node_data)
@@ -140,6 +142,7 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
140142
def init_fields(self):
141143
field_list = []
142144
global_field_list = []
145+
chat_field_list = []
143146
for node in self.flow.nodes:
144147
properties = node.properties
145148
node_name = properties.get('stepName')
@@ -154,10 +157,16 @@ def init_fields(self):
154157
if global_fields is not None:
155158
for global_field in global_fields:
156159
global_field_list.append({**global_field, 'node_id': node_id, 'node_name': node_name})
160+
chat_fields = node_config.get('chatFields')
161+
if chat_fields is not None:
162+
for chat_field in chat_fields:
163+
chat_field_list.append({**chat_field, 'node_id': node_id, 'node_name': node_name})
157164
field_list.sort(key=lambda f: len(f.get('node_name') + f.get('value')), reverse=True)
158165
global_field_list.sort(key=lambda f: len(f.get('node_name') + f.get('value')), reverse=True)
166+
chat_field_list.sort(key=lambda f: len(f.get('node_name') + f.get('value')), reverse=True)
159167
self.field_list = field_list
160168
self.global_field_list = global_field_list
169+
self.chat_field_list = chat_field_list
161170

162171
def append_answer(self, content):
163172
self.answer += content
@@ -445,6 +454,9 @@ def is_result(self, current_node, current_node_result):
445454
return current_node.node_params.get('is_result', not self._has_next_node(
446455
current_node, current_node_result)) if current_node.node_params is not None else False
447456

457+
def get_chat_info(self):
458+
return self.work_flow_post_handler.chat_info
459+
448460
def get_chunk_content(self, chunk, is_end=False):
449461
return 'data: ' + json.dumps(
450462
{'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True,
@@ -587,12 +599,15 @@ def get_reference_field(self, node_id: str, fields: List[str]):
587599
"""
588600
if node_id == 'global':
589601
return INode.get_field(self.context, fields)
602+
elif node_id == 'chat':
603+
return INode.get_field(self.chat_context, fields)
590604
else:
591605
return self.get_node_by_id(node_id).get_reference_field(fields)
592606

593607
def get_workflow_content(self):
594608
context = {
595609
'global': self.context,
610+
'chat': self.chat_context
596611
}
597612

598613
for node in self.node_context:
@@ -610,6 +625,10 @@ def reset_prompt(self, prompt: str):
610625
globeLabelNew = f"global.{field.get('value')}"
611626
globeValue = f"context.get('global').get('{field.get('value', '')}','')"
612627
prompt = prompt.replace(globeLabel, globeValue).replace(globeLabelNew, globeValue)
628+
for field in self.chat_field_list:
629+
chatLabel = f"chat.{field.get('value')}"
630+
chatValue = f"context.get('chat').get('{field.get('value', '')}','')"
631+
prompt = prompt.replace(chatLabel, chatValue)
613632
return prompt
614633

615634
def generate_prompt(self, prompt: str):

apps/application/serializers/common.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,34 @@ def to_pipeline_manage_params(self, problem_text: str, post_response_handler: Po
166166
'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'chat_user_id': chat_user_id,
167167
'chat_user_type': chat_user_type, 'form_data': form_data}
168168

169+
def set_chat(self, question):
170+
if not self.debug:
171+
if not QuerySet(Chat).filter(id=self.chat_id).exists():
172+
Chat(id=self.chat_id, application_id=self.application_id, abstract=question[0:1024],
173+
chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type,
174+
asker=self.get_chat_user()).save()
175+
176+
def set_chat_variable(self, chat_context):
177+
if not self.debug:
178+
chat = QuerySet(Chat).filter(id=self.chat_id).first()
179+
if chat:
180+
chat.meta = {**(chat.meta if isinstance(chat.meta, dict) else {}), **chat_context}
181+
chat.save()
182+
else:
183+
cache.set(Cache_Version.CHAT_VARIABLE.get_key(key=self.chat_id), chat_context,
184+
version=Cache_Version.CHAT_VARIABLE.get_version(),
185+
timeout=60 * 30)
186+
187+
def get_chat_variable(self):
188+
if not self.debug:
189+
chat = QuerySet(Chat).filter(id=self.chat_id).first()
190+
if chat:
191+
return chat.meta
192+
return {}
193+
else:
194+
return cache.get(Cache_Version.CHAT_VARIABLE.get_key(key=self.chat_id),
195+
version=Cache_Version.CHAT_VARIABLE.get_version()) or {}
196+
169197
def append_chat_record(self, chat_record: ChatRecord):
170198
chat_record.problem_text = chat_record.problem_text[0:10240] if chat_record.problem_text is not None else ""
171199
chat_record.answer_text = chat_record.answer_text[0:40960] if chat_record.problem_text is not None else ""

apps/chat/serializers/chat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def chat_simple(self, chat_info: ChatInfo, instance, base_to_response):
253253
# 构建运行参数
254254
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
255255
chat_user_id, chat_user_type, stream, form_data)
256+
chat_info.set_chat(message)
256257
# 运行流水线作业
257258
pipeline_message.run(params)
258259
return pipeline_message.context['chat_result']
@@ -307,6 +308,7 @@ def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response):
307308
other_list,
308309
instance.get('runtime_node_id'),
309310
instance.get('node_data'), chat_record, instance.get('child_node'))
311+
chat_info.set_chat(message)
310312
r = work_flow_manage.run()
311313
return r
312314

apps/common/constants/cache_version.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ class Cache_Version(Enum):
2929

3030
# 对话
3131
CHAT = "CHAT", lambda key: key
32+
33+
CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key
34+
3235
# 应用API KEY
3336
APPLICATION_API_KEY = "APPLICATION_API_KEY", lambda secret_key, use_get_data: secret_key
3437

ui/src/locales/lang/zh-CN/views/application-workflow.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ export default {
5252
variable: {
5353
label: '变量',
5454
global: '全局变量',
55+
chat: '会话变量',
5556
Referencing: '引用变量',
5657
ReferencingRequired: '引用变量必填',
5758
ReferencingError: '引用变量错误',

ui/src/workflow/common/NodeCascader.vue

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ const wheel = (e: any) => {
5151
function visibleChange(bool: boolean) {
5252
if (bool) {
5353
options.value = props.global
54-
? props.nodeModel.get_up_node_field_list(false, true).filter((v: any) => v.value === 'global')
54+
? props.nodeModel
55+
.get_up_node_field_list(false, true)
56+
.filter((v: any) => ['global', 'chat'].includes(v.value))
5557
: props.nodeModel.get_up_node_field_list(false, true)
5658
}
5759
}

0 commit comments

Comments
 (0)