Skip to content

Commit 7c06134

Browse files
committed
feat: Support reasoning content (WIP)
1 parent 12bca69 commit 7c06134

File tree

6 files changed

+50
-21
lines changed

6 files changed

+50
-21
lines changed

apps/application/flow/common.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,18 +9,22 @@
99

1010

1111
class Answer:
12-
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, reasoning_content=None):
12+
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
13+
reasoning_content):
1314
self.view_type = view_type
1415
self.content = content
1516
self.reasoning_content = reasoning_content
1617
self.runtime_node_id = runtime_node_id
1718
self.chat_record_id = chat_record_id
1819
self.child_node = child_node
20+
self.real_node_id = real_node_id
1921

2022
def to_dict(self):
2123
return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
22-
'chat_record_id': self.chat_record_id, 'child_node': self.child_node,
23-
'reasoning_content': self.reasoning_content}
24+
'chat_record_id': self.chat_record_id,
25+
'child_node': self.child_node,
26+
'reasoning_content': self.reasoning_content,
27+
'real_node_id': self.real_node_id}
2428

2529

2630
class NodeChunk:

apps/application/flow/i_step_node.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,8 @@ def get_answer_list(self) -> List[Answer] | None:
158158
if self.answer_text is None:
159159
return None
160160
return [
161-
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {})]
161+
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
162+
self.runtime_node_id, self.context.get('reasoning_content'))]
162163

163164
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
164165
get_node_params=lambda node: node.properties.get('node_data')):

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
6969
"""
7070
response = node_variable.get('result')
7171
answer = response.content
72-
_write_context(node_variable, workflow_variable, node, workflow, answer)
72+
reasoning_content = response.response_metadata.get('reasoning_content', '')
73+
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
7374

7475

7576
def get_default_model_params_setting(model_id):
@@ -103,13 +104,6 @@ def save_context(self, details, workflow_manage):
103104
self.context['reasoning_content'] = details.get('reasoning_content')
104105
self.answer_text = details.get('answer')
105106

106-
def get_answer_list(self) -> List[Answer] | None:
107-
if self.answer_text is None:
108-
return None
109-
return [
110-
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
111-
self.context.get('reasoning_content'))]
112-
113107
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
114108
model_params_setting=None,
115109
dialogue_type=None,

apps/application/flow/step_node/application_node/impl/base_application_node.py

Lines changed: 19 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,8 @@ def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict):
1919
return node_variable.get('is_interrupt_exec', False)
2020

2121

22-
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
22+
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
23+
reasoning_content: str):
2324
result = node_variable.get('result')
2425
node.context['application_node_dict'] = node_variable.get('application_node_dict')
2526
node.context['node_dict'] = node_variable.get('node_dict', {})
@@ -28,6 +29,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo
2829
node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
2930
node.context['answer'] = answer
3031
node.context['result'] = answer
32+
node.context['reasoning_content'] = reasoning_content
3133
node.context['question'] = node_variable['question']
3234
node.context['run_time'] = time.time() - node.context['start_time']
3335
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
@@ -44,6 +46,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
4446
"""
4547
response = node_variable.get('result')
4648
answer = ''
49+
reasoning_content = ''
4750
usage = {}
4851
node_child_node = {}
4952
application_node_dict = node.context.get('application_node_dict', {})
@@ -60,9 +63,11 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
6063
node_type = response_content.get('node_type')
6164
real_node_id = response_content.get('real_node_id')
6265
node_is_end = response_content.get('node_is_end', False)
66+
_reasoning_content = response_content.get('reasoning_content', '')
6367
if node_type == 'form-node':
6468
is_interrupt_exec = True
6569
answer += content
70+
reasoning_content += _reasoning_content
6671
node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
6772
'child_node': child_node}
6873

@@ -75,13 +80,16 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
7580
'chat_record_id': chat_record_id,
7681
'child_node': child_node,
7782
'index': len(application_node_dict),
78-
'view_type': view_type}
83+
'view_type': view_type,
84+
'reasoning_content': _reasoning_content}
7985
else:
8086
application_node['content'] += content
87+
application_node['reasoning_content'] += _reasoning_content
8188

8289
yield {'content': content,
8390
'node_type': node_type,
8491
'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
92+
'reasoning_content': _reasoning_content,
8593
'child_node': child_node,
8694
'real_node_id': real_node_id,
8795
'node_is_end': node_is_end,
@@ -91,7 +99,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
9199
node_variable['is_interrupt_exec'] = is_interrupt_exec
92100
node_variable['child_node'] = node_child_node
93101
node_variable['application_node_dict'] = application_node_dict
94-
_write_context(node_variable, workflow_variable, node, workflow, answer)
102+
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
95103

96104

97105
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -106,7 +114,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
106114
node_variable['result'] = {'usage': {'completion_tokens': response.get('completion_tokens'),
107115
'prompt_tokens': response.get('prompt_tokens')}}
108116
answer = response.get('content', '') or "抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。"
109-
_write_context(node_variable, workflow_variable, node, workflow, answer)
117+
reasoning_content = response.get('reasoning_content', '')
118+
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
110119

111120

112121
def reset_application_node_dict(application_node_dict, runtime_node_id, node_data):
@@ -139,18 +148,21 @@ def get_answer_list(self) -> List[Answer] | None:
139148
if application_node_dict is None or len(application_node_dict) == 0:
140149
return [
141150
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'],
142-
self.context.get('child_node'))]
151+
self.context.get('child_node'), self.runtime_node_id, '')]
143152
else:
144153
return [Answer(n.get('content'), n.get('view_type'), self.runtime_node_id,
145154
self.workflow_params['chat_record_id'], {'runtime_node_id': n.get('runtime_node_id'),
146155
'chat_record_id': n.get('chat_record_id')
147-
, 'child_node': n.get('child_node')}) for n in
156+
, 'child_node': n.get('child_node')}, n.get('real_node_id'), n.get('reasoning_content'))
157+
for n in
148158
sorted(application_node_dict.values(), key=lambda item: item.get('index'))]
149159

150160
def save_context(self, details, workflow_manage):
151161
self.context['answer'] = details.get('answer')
162+
self.context['result'] = details.get('answer')
152163
self.context['question'] = details.get('question')
153164
self.context['type'] = details.get('type')
165+
self.context['reasoning_content'] = details.get('reasoning_content')
154166
self.answer_text = details.get('answer')
155167

156168
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
@@ -229,6 +241,7 @@ def get_details(self, index: int, **kwargs):
229241
'run_time': self.context.get('run_time'),
230242
'question': self.context.get('question'),
231243
'answer': self.context.get('answer'),
244+
'reasoning_content': self.context.get('reasoning_content'),
232245
'type': self.node.type,
233246
'message_tokens': self.context.get('message_tokens'),
234247
'answer_tokens': self.context.get('answer_tokens'),

apps/application/flow/step_node/form_node/impl/base_form_node.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,8 @@ def get_answer_list(self) -> List[Answer] | None:
7575
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
7676
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
7777
value = prompt_template.format(form=form, context=context)
78-
return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None)]
78+
return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None,
79+
self.runtime_node_id, '')]
7980

8081
def get_details(self, index: int, **kwargs):
8182
form_content_format = self.context.get('form_content_format')

apps/setting/models_provider/impl/base_chat_open_ai.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# coding=utf-8
22
import warnings
3-
from typing import List, Dict, Optional, Any, Iterator, cast, Type
3+
from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union
44

55
import openai
66
from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -105,7 +105,8 @@ def _stream(
105105
self.usage_metadata = generation_chunk.message.usage_metadata
106106
# custom code
107107
if 'reasoning_content' in chunk['choices'][0]['delta']:
108-
generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta']['reasoning_content']
108+
generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta'][
109+
'reasoning_content']
109110

110111
default_chunk_class = generation_chunk.message.__class__
111112
logprobs = (generation_chunk.generation_info or {}).get("logprobs")
@@ -116,6 +117,21 @@ def _stream(
116117
is_first_chunk = False
117118
yield generation_chunk
118119

120+
def _create_chat_result(self,
121+
response: Union[dict, openai.BaseModel],
122+
generation_info: Optional[Dict] = None):
123+
result = super()._create_chat_result(response, generation_info)
124+
try:
125+
reasoning_content = ''
126+
for res in response.choices:
127+
_reasoning_content = res.message.model_extra.get('reasoning_content')
128+
if _reasoning_content is not None:
129+
reasoning_content += _reasoning_content
130+
result.llm_output['reasoning_content'] = reasoning_content
131+
except Exception as e:
132+
pass
133+
return result
134+
119135
def invoke(
120136
self,
121137
input: LanguageModelInput,

0 commit comments

Comments
 (0)