Skip to content

Commit e79d859

Browse files
committed
feat: Support reasoning content(WIP)
1 parent 5d3cd58 commit e79d859

File tree

8 files changed

+206
-93
lines changed

8 files changed

+206
-93
lines changed

apps/application/flow/common.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99

1010

1111
class Answer:
12-
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node):
12+
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, reasoning_content=None):
1313
self.view_type = view_type
1414
self.content = content
15+
self.reasoning_content = reasoning_content
1516
self.runtime_node_id = runtime_node_id
1617
self.chat_record_id = chat_record_id
1718
self.child_node = child_node
1819

1920
def to_dict(self):
2021
return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
21-
'chat_record_id': self.chat_record_id, 'child_node': self.child_node}
22+
'chat_record_id': self.chat_record_id, 'child_node': self.child_node,
23+
'reasoning_content': self.reasoning_content}
2224

2325

2426
class NodeChunk:

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414
from langchain.schema import HumanMessage, SystemMessage
1515
from langchain_core.messages import BaseMessage, AIMessage
1616

17+
from application.flow.common import Answer
1718
from application.flow.i_step_node import NodeResult, INode
1819
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
1920
from setting.models import Model
2021
from setting.models_provider import get_model_credential
2122
from setting.models_provider.tools import get_model_instance_by_model_user_id
2223

2324

24-
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
25+
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
26+
reasoning_content: str):
2527
chat_model = node_variable.get('chat_model')
2628
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
2729
answer_tokens = chat_model.get_num_tokens(answer)
@@ -31,6 +33,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo
3133
node.context['history_message'] = node_variable['history_message']
3234
node.context['question'] = node_variable['question']
3335
node.context['run_time'] = time.time() - node.context['start_time']
36+
node.context['reasoning_content'] = reasoning_content
3437
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
3538
node.answer_text = answer
3639

@@ -45,10 +48,15 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
4548
"""
4649
response = node_variable.get('result')
4750
answer = ''
51+
reasoning_content = ''
4852
for chunk in response:
4953
answer += chunk.content
50-
yield chunk.content
51-
_write_context(node_variable, workflow_variable, node, workflow, answer)
54+
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
55+
if reasoning_content_chunk is None:
56+
reasoning_content_chunk = ''
57+
reasoning_content += reasoning_content_chunk
58+
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
59+
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
5260

5361

5462
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@@ -92,8 +100,16 @@ class BaseChatNode(IChatNode):
92100
def save_context(self, details, workflow_manage):
93101
self.context['answer'] = details.get('answer')
94102
self.context['question'] = details.get('question')
103+
self.context['reasoning_content'] = details.get('reasoning_content')
95104
self.answer_text = details.get('answer')
96105

106+
def get_answer_list(self) -> List[Answer] | None:
107+
if self.answer_text is None:
108+
return None
109+
return [
110+
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
111+
self.context.get('reasoning_content'))]
112+
97113
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
98114
model_params_setting=None,
99115
dialogue_type=None,
@@ -164,6 +180,7 @@ def get_details(self, index: int, **kwargs):
164180
'history_message') is not None else [])],
165181
'question': self.context.get('question'),
166182
'answer': self.context.get('answer'),
183+
'reasoning_content': self.context.get('reasoning_content'),
167184
'type': self.node.type,
168185
'message_tokens': self.context.get('message_tokens'),
169186
'answer_tokens': self.context.get('answer_tokens'),

apps/application/flow/workflow_manage.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ def hand_event_node_result(self, current_node, node_result_future):
470470
if result is not None:
471471
if self.is_result(current_node, current_result):
472472
for r in result:
473+
reasoning_content = ''
473474
content = r
474475
child_node = {}
475476
node_is_end = False
@@ -479,9 +480,12 @@ def hand_event_node_result(self, current_node, node_result_future):
479480
child_node = {'runtime_node_id': r.get('runtime_node_id'),
480481
'chat_record_id': r.get('chat_record_id')
481482
, 'child_node': r.get('child_node')}
482-
real_node_id = r.get('real_node_id')
483-
node_is_end = r.get('node_is_end')
483+
if r.__contains__('real_node_id'):
484+
real_node_id = r.get('real_node_id')
485+
if r.__contains__('node_is_end'):
486+
node_is_end = r.get('node_is_end')
484487
view_type = r.get('view_type')
488+
reasoning_content = r.get('reasoning_content')
485489
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
486490
self.params['chat_record_id'],
487491
current_node.id,
@@ -492,7 +496,8 @@ def hand_event_node_result(self, current_node, node_result_future):
492496
'view_type': view_type,
493497
'child_node': child_node,
494498
'node_is_end': node_is_end,
495-
'real_node_id': real_node_id})
499+
'real_node_id': real_node_id,
500+
'reasoning_content': reasoning_content})
496501
current_node.node_chunk.add_chunk(chunk)
497502
chunk = (self.base_to_response
498503
.to_stream_chunk_response(self.params['chat_id'],
@@ -504,7 +509,8 @@ def hand_event_node_result(self, current_node, node_result_future):
504509
'node_type': current_node.type,
505510
'view_type': view_type,
506511
'child_node': child_node,
507-
'real_node_id': real_node_id}))
512+
'real_node_id': real_node_id,
513+
'reasoning_content': ''}))
508514
current_node.node_chunk.add_chunk(chunk)
509515
else:
510516
list(result)

ui/src/api/type/application.ts

Lines changed: 94 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ interface Chunk {
2929
chat_id: string
3030
chat_record_id: string
3131
content: string
32+
reasoning_content: string
3233
node_id: string
3334
up_node_id: string
3435
is_end: boolean
@@ -43,12 +44,16 @@ interface chatType {
4344
problem_text: string
4445
answer_text: string
4546
buffer: Array<String>
46-
answer_text_list: Array<{
47-
content: string
48-
chat_record_id?: string
49-
runtime_node_id?: string
50-
child_node?: any
51-
}>
47+
answer_text_list: Array<
48+
Array<{
49+
content: string
50+
reasoning_content: string
51+
chat_record_id?: string
52+
runtime_node_id?: string
53+
child_node?: any
54+
real_node_id?: string
55+
}>
56+
>
5257
/**
5358
* 是否写入结束
5459
*/
@@ -83,6 +88,7 @@ interface WriteNodeInfo {
8388
answer_text_list_index: number
8489
current_up_node?: any
8590
divider_content?: Array<string>
91+
divider_reasoning_content?: Array<string>
8692
}
8793
export class ChatRecordManage {
8894
id?: any
@@ -105,20 +111,38 @@ export class ChatRecordManage {
105111
}
106112
append_answer(
107113
chunk_answer: string,
114+
reasoning_content: string,
108115
index?: number,
109116
chat_record_id?: string,
110117
runtime_node_id?: string,
111-
child_node?: any
118+
child_node?: any,
119+
real_node_id?: string
112120
) {
113-
const set_index = index != undefined ? index : this.chat.answer_text_list.length - 1
114-
const content = this.chat.answer_text_list[set_index]
115-
? this.chat.answer_text_list[set_index].content + chunk_answer
116-
: chunk_answer
117-
this.chat.answer_text_list[set_index] = {
118-
content: content,
119-
chat_record_id,
120-
runtime_node_id,
121-
child_node
121+
if (chunk_answer || reasoning_content) {
122+
const set_index = index != undefined ? index : this.chat.answer_text_list.length - 1
123+
let card_list = this.chat.answer_text_list[set_index]
124+
if (!card_list) {
125+
card_list = []
126+
this.chat.answer_text_list[set_index] = card_list
127+
}
128+
const answer_value = card_list.find((item) => item.real_node_id == real_node_id)
129+
const content = answer_value ? answer_value.content + chunk_answer : chunk_answer
130+
const _reasoning_content = answer_value
131+
? answer_value.reasoning_content + reasoning_content
132+
: reasoning_content
133+
if (answer_value) {
134+
answer_value.content = content
135+
answer_value.reasoning_content = _reasoning_content
136+
} else {
137+
card_list.push({
138+
content: content,
139+
reasoning_content: _reasoning_content,
140+
chat_record_id,
141+
runtime_node_id,
142+
child_node,
143+
real_node_id
144+
})
145+
}
122146
}
123147

124148
this.chat.answer_text = this.chat.answer_text + chunk_answer
@@ -155,7 +179,7 @@ export class ChatRecordManage {
155179
) {
156180
const none_index = this.findIndex(
157181
this.chat.answer_text_list,
158-
(item) => item.content == '',
182+
(item) => (item.length == 1 && item[0].content == '') || item.length == 0,
159183
'index'
160184
)
161185
if (none_index > -1) {
@@ -166,7 +190,7 @@ export class ChatRecordManage {
166190
} else {
167191
const none_index = this.findIndex(
168192
this.chat.answer_text_list,
169-
(item) => item.content === '',
193+
(item) => (item.length == 1 && item[0].content == '') || item.length == 0,
170194
'index'
171195
)
172196
if (none_index > -1) {
@@ -178,10 +202,10 @@ export class ChatRecordManage {
178202

179203
this.write_node_info = {
180204
current_node: run_node,
181-
divider_content: ['\n\n'],
182205
current_up_node: current_up_node,
183206
answer_text_list_index: answer_text_list_index
184207
}
208+
185209
return this.write_node_info
186210
}
187211
return undefined
@@ -210,7 +234,7 @@ export class ChatRecordManage {
210234
}
211235
const last_index = this.findIndex(
212236
this.chat.answer_text_list,
213-
(item) => item.content == '',
237+
(item) => (item.length == 1 && item[0].content == '') || item.length == 0,
214238
'last'
215239
)
216240
if (last_index > 0) {
@@ -234,20 +258,29 @@ export class ChatRecordManage {
234258
}
235259
return
236260
}
237-
const { current_node, answer_text_list_index, divider_content } = node_info
261+
const { current_node, answer_text_list_index } = node_info
262+
238263
if (current_node.buffer.length > 20) {
239264
const context = current_node.is_end
240265
? current_node.buffer.splice(0)
241266
: current_node.buffer.splice(
242267
0,
243268
current_node.is_end ? undefined : current_node.buffer.length - 20
244269
)
270+
const reasoning_content = current_node.is_end
271+
? current_node.reasoning_content_buffer.splice(0)
272+
: current_node.reasoning_content_buffer.splice(
273+
0,
274+
current_node.is_end ? undefined : current_node.reasoning_content_buffer.length - 20
275+
)
245276
this.append_answer(
246-
(divider_content ? divider_content.splice(0).join('') : '') + context.join(''),
277+
context.join(''),
278+
reasoning_content.join(''),
247279
answer_text_list_index,
248280
current_node.chat_record_id,
249281
current_node.runtime_node_id,
250-
current_node.child_node
282+
current_node.child_node,
283+
current_node.real_node_id
251284
)
252285
} else if (this.is_close) {
253286
while (true) {
@@ -259,25 +292,44 @@ export class ChatRecordManage {
259292
this.append_answer(
260293
(node_info.divider_content ? node_info.divider_content.splice(0).join('') : '') +
261294
node_info.current_node.buffer.splice(0).join(''),
295+
(node_info.divider_reasoning_content
296+
? node_info.divider_reasoning_content.splice(0).join('')
297+
: '') + node_info.current_node.reasoning_content_buffer.splice(0).join(''),
262298
node_info.answer_text_list_index,
263299
node_info.current_node.chat_record_id,
264300
node_info.current_node.runtime_node_id,
265-
node_info.current_node.child_node
301+
node_info.current_node.child_node,
302+
node_info.current_node.real_node_id
266303
)
304+
267305
if (node_info.current_node.buffer.length == 0) {
268306
node_info.current_node.is_end = true
269307
}
270308
}
271309
this.closeInterval()
272310
} else {
273311
const s = current_node.buffer.shift()
312+
const reasoning_content = current_node.reasoning_content_buffer.shift()
274313
if (s !== undefined) {
275314
this.append_answer(
276-
(divider_content ? divider_content.splice(0).join('') : '') + s,
315+
s,
316+
'',
277317
answer_text_list_index,
278318
current_node.chat_record_id,
279319
current_node.runtime_node_id,
280-
current_node.child_node
320+
current_node.child_node,
321+
current_node.real_node_id
322+
)
323+
}
324+
if (reasoning_content !== undefined) {
325+
this.append_answer(
326+
'',
327+
reasoning_content,
328+
answer_text_list_index,
329+
current_node.chat_record_id,
330+
current_node.runtime_node_id,
331+
current_node.child_node,
332+
current_node.real_node_id
281333
)
282334
}
283335
}
@@ -303,9 +355,15 @@ export class ChatRecordManage {
303355
if (n) {
304356
n.buffer.push(...chunk.content)
305357
n.content += chunk.content
358+
if (chunk.reasoning_content) {
359+
n.reasoning_content_buffer.push(...chunk.reasoning_content)
360+
n.reasoning_content += chunk.reasoning_content
361+
}
306362
} else {
307363
n = {
308364
buffer: [...chunk.content],
365+
reasoning_content_buffer: chunk.reasoning_content ? [...chunk.reasoning_content] : [],
366+
reasoning_content: chunk.reasoning_content ? chunk.reasoning_content : '',
309367
content: chunk.content,
310368
real_node_id: chunk.real_node_id,
311369
node_id: chunk.node_id,
@@ -324,13 +382,18 @@ export class ChatRecordManage {
324382
n['is_end'] = true
325383
}
326384
}
327-
append(answer_text_block: string) {
385+
append(answer_text_block: string, reasoning_content?: string) {
328386
let set_index = this.findIndex(
329387
this.chat.answer_text_list,
330-
(item) => item.content == '',
388+
(item) => item.length == 1 && item[0].content == '',
331389
'index'
332390
)
333-
this.chat.answer_text_list[set_index] = { content: answer_text_block }
391+
this.chat.answer_text_list[set_index] = [
392+
{
393+
content: answer_text_block,
394+
reasoning_content: reasoning_content ? reasoning_content : ''
395+
}
396+
]
334397
}
335398
}
336399

@@ -346,10 +409,10 @@ export class ChatManagement {
346409
chatRecord.appendChunk(chunk)
347410
}
348411
}
349-
static append(chatRecordId: string, content: string) {
412+
static append(chatRecordId: string, content: string, reasoning_content?: string) {
350413
const chatRecord = this.chatMessageContainer[chatRecordId]
351414
if (chatRecord) {
352-
chatRecord.append(content)
415+
chatRecord.append(content, reasoning_content)
353416
}
354417
}
355418
static updateStatus(chatRecordId: string, code: number) {

0 commit comments

Comments
 (0)