
Commit cc150c3

feat: support parse <think> reasoning block
1 parent aa0218d commit cc150c3

2 files changed: +154 -111 lines changed


backend/apps/chat/task/llm.py

Lines changed: 150 additions & 111 deletions
@@ -6,7 +6,7 @@
 import warnings
 from concurrent.futures import ThreadPoolExecutor, Future
 from datetime import datetime
-from typing import Any, List, Optional, Union, Dict
+from typing import Any, List, Optional, Union, Dict, Iterator

 import numpy as np
 import orjson
@@ -259,22 +259,14 @@ def generate_analysis(self):
                                  in analysis_msg])
         full_thinking_text = ''
         full_analysis_text = ''
-        res = self.llm.stream(analysis_msg)
         token_usage = {}
+        res = process_stream(self.llm.stream(analysis_msg), token_usage)
         for chunk in res:
-            SQLBotLogUtil.info(chunk)
-            reasoning_content_chunk = ''
-            if 'reasoning_content' in chunk.additional_kwargs:
-                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
-            # else:
-            #     reasoning_content_chunk = chunk.get('reasoning_content')
-            if reasoning_content_chunk is None:
-                reasoning_content_chunk = ''
-            full_thinking_text += reasoning_content_chunk
-
-            full_analysis_text += chunk.content
-            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
-            get_token_usage(chunk, token_usage)
+            if chunk.get('content'):
+                full_analysis_text += chunk.get('content')
+            if chunk.get('reasoning_content'):
+                full_thinking_text += chunk.get('reasoning_content')
+            yield chunk

         analysis_msg.append(AIMessage(full_analysis_text))

@@ -311,22 +303,14 @@ def generate_predict(self):
                                  in predict_msg])
         full_thinking_text = ''
         full_predict_text = ''
-        res = self.llm.stream(predict_msg)
         token_usage = {}
+        res = process_stream(self.llm.stream(predict_msg), token_usage)
         for chunk in res:
-            SQLBotLogUtil.info(chunk)
-            reasoning_content_chunk = ''
-            if 'reasoning_content' in chunk.additional_kwargs:
-                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
-            # else:
-            #     reasoning_content_chunk = chunk.get('reasoning_content')
-            if reasoning_content_chunk is None:
-                reasoning_content_chunk = ''
-            full_thinking_text += reasoning_content_chunk
-
-            full_predict_text += chunk.content
-            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
-            get_token_usage(chunk, token_usage)
+            if chunk.get('content'):
+                full_predict_text += chunk.get('content')
+            if chunk.get('reasoning_content'):
+                full_thinking_text += chunk.get('reasoning_content')
+            yield chunk

         predict_msg.append(AIMessage(full_predict_text))
         self.record = save_predict_answer(session=self.session, record_id=self.record.id,
@@ -370,21 +354,13 @@ def generate_recommend_questions_task(self):
         full_thinking_text = ''
         full_guess_text = ''
         token_usage = {}
-        res = self.llm.stream(guess_msg)
+        res = process_stream(self.llm.stream(guess_msg), token_usage)
         for chunk in res:
-            SQLBotLogUtil.info(chunk)
-            reasoning_content_chunk = ''
-            if 'reasoning_content' in chunk.additional_kwargs:
-                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
-            # else:
-            #     reasoning_content_chunk = chunk.get('reasoning_content')
-            if reasoning_content_chunk is None:
-                reasoning_content_chunk = ''
-            full_thinking_text += reasoning_content_chunk
-
-            full_guess_text += chunk.content
-            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
-            get_token_usage(chunk, token_usage)
+            if chunk.get('content'):
+                full_guess_text += chunk.get('content')
+            if chunk.get('reasoning_content'):
+                full_thinking_text += chunk.get('reasoning_content')
+            yield chunk

         guess_msg.append(AIMessage(full_guess_text))

@@ -450,21 +426,13 @@ def select_datasource(self):
                                  msg in datasource_msg])

         token_usage = {}
-        res = self.llm.stream(datasource_msg)
+        res = process_stream(self.llm.stream(datasource_msg), token_usage)
         for chunk in res:
-            SQLBotLogUtil.info(chunk)
-            reasoning_content_chunk = ''
-            if 'reasoning_content' in chunk.additional_kwargs:
-                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
-            # else:
-            #     reasoning_content_chunk = chunk.get('reasoning_content')
-            if reasoning_content_chunk is None:
-                reasoning_content_chunk = ''
-            full_thinking_text += reasoning_content_chunk
-
-            full_text += chunk.content
-            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
-            get_token_usage(chunk, token_usage)
+            if chunk.get('content'):
+                full_text += chunk.get('content')
+            if chunk.get('reasoning_content'):
+                full_thinking_text += chunk.get('reasoning_content')
+            yield chunk
         datasource_msg.append(AIMessage(full_text))

         self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
@@ -560,21 +528,13 @@ def generate_sql(self):
         full_thinking_text = ''
         full_sql_text = ''
         token_usage = {}
-        res = self.llm.stream(self.sql_message)
+        res = process_stream(self.llm.stream(self.sql_message), token_usage)
         for chunk in res:
-            SQLBotLogUtil.info(chunk)
-            reasoning_content_chunk = ''
-            if 'reasoning_content' in chunk.additional_kwargs:
-                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
-            # else:
-            #     reasoning_content_chunk = chunk.get('reasoning_content')
-            if reasoning_content_chunk is None:
-                reasoning_content_chunk = ''
-            full_thinking_text += reasoning_content_chunk
-
-            full_sql_text += chunk.content
-            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
-            get_token_usage(chunk, token_usage)
+            if chunk.get('content'):
+                full_sql_text += chunk.get('content')
+            if chunk.get('reasoning_content'):
+                full_thinking_text += chunk.get('reasoning_content')
+            yield chunk

         self.sql_message.append(AIMessage(full_sql_text))

@@ -607,18 +567,14 @@ def generate_with_sub_sql(self, sql, sub_mappings: list):

         full_thinking_text = ''
         full_dynamic_text = ''
-        res = self.llm.stream(dynamic_sql_msg)
         token_usage = {}
+        res = process_stream(self.llm.stream(dynamic_sql_msg), token_usage)
         for chunk in res:
-            SQLBotLogUtil.info(chunk)
-            reasoning_content_chunk = ''
-            if 'reasoning_content' in chunk.additional_kwargs:
-                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
-            if reasoning_content_chunk is None:
-                reasoning_content_chunk = ''
-            full_thinking_text += reasoning_content_chunk
-            full_dynamic_text += chunk.content
-            get_token_usage(chunk, token_usage)
+            if chunk.get('content'):
+                full_dynamic_text += chunk.get('content')
+            if chunk.get('reasoning_content'):
+                full_thinking_text += chunk.get('reasoning_content')
+            yield chunk

         dynamic_sql_msg.append(AIMessage(full_dynamic_text))

@@ -670,22 +626,13 @@ def build_table_filter(self, sql: str, filters: list):
                                  in permission_sql_msg])
         full_thinking_text = ''
         full_filter_text = ''
-        res = self.llm.stream(permission_sql_msg)
         token_usage = {}
+        res = process_stream(self.llm.stream(permission_sql_msg), token_usage)
         for chunk in res:
-            SQLBotLogUtil.info(chunk)
-            reasoning_content_chunk = ''
-            if 'reasoning_content' in chunk.additional_kwargs:
-                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
-            # else:
-            #     reasoning_content_chunk = chunk.get('reasoning_content')
-            if reasoning_content_chunk is None:
-                reasoning_content_chunk = ''
-            full_thinking_text += reasoning_content_chunk
-
-            full_filter_text += chunk.content
-            # yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
-            get_token_usage(chunk, token_usage)
+            if chunk.get('content'):
+                full_filter_text += chunk.get('content')
+            if chunk.get('reasoning_content'):
+                full_thinking_text += chunk.get('reasoning_content')

         permission_sql_msg.append(AIMessage(full_filter_text))

@@ -735,21 +682,13 @@ def generate_chart(self, chart_type: Optional[str] = ''):
         full_thinking_text = ''
         full_chart_text = ''
         token_usage = {}
-        res = self.llm.stream(self.chart_message)
+        res = process_stream(self.llm.stream(self.chart_message), token_usage)
         for chunk in res:
-            SQLBotLogUtil.info(chunk)
-            reasoning_content_chunk = ''
-            if 'reasoning_content' in chunk.additional_kwargs:
-                reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
-            # else:
-            #     reasoning_content_chunk = chunk.get('reasoning_content')
-            if reasoning_content_chunk is None:
-                reasoning_content_chunk = ''
-            full_thinking_text += reasoning_content_chunk
-
-            full_chart_text += chunk.content
-            yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
-            get_token_usage(chunk, token_usage)
+            if chunk.get('content'):
+                full_chart_text += chunk.get('content')
+            if chunk.get('reasoning_content'):
+                full_thinking_text += chunk.get('reasoning_content')
+            yield chunk

         self.chart_message.append(AIMessage(full_chart_text))

@@ -1053,7 +992,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
         else:
             sql = self.check_save_sql(res=full_sql_text)

-        SQLBotLogUtil.info(sql)
+        SQLBotLogUtil.info('sql: ' + sql)

         if not stream:
             json_result['sql'] = sql
@@ -1372,16 +1311,116 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
     return request_path


-def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}):
+def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = None):
     try:
         if chunk.usage_metadata:
+            if token_usage is None:
+                token_usage = {}
             token_usage['input_tokens'] = chunk.usage_metadata.get('input_tokens')
             token_usage['output_tokens'] = chunk.usage_metadata.get('output_tokens')
             token_usage['total_tokens'] = chunk.usage_metadata.get('total_tokens')
     except Exception:
         pass


+def process_stream(res: Iterator[BaseMessageChunk],
+                   token_usage: Dict[str, Any] = None,
+                   enable_tag_parsing: bool = settings.PARSE_REASONING_BLOCK_ENABLED,
+                   start_tag: str = settings.DEFAULT_REASONING_CONTENT_START,
+                   end_tag: str = settings.DEFAULT_REASONING_CONTENT_END
+                   ):
+    if token_usage is None:
+        token_usage = {}
+    in_thinking_block = False  # whether we are currently inside a reasoning block
+    current_thinking = ''  # reasoning content collected so far
+    pending_start_tag = ''  # buffers a start tag that may be truncated across chunks
+
+    for chunk in res:
+        SQLBotLogUtil.info(chunk)
+        reasoning_content_chunk = ''
+        content = chunk.content
+        output_content = ''  # the content that will actually be emitted
+
+        # check additional_kwargs for reasoning_content
+        if 'reasoning_content' in chunk.additional_kwargs:
+            reasoning_content = chunk.additional_kwargs.get('reasoning_content', '')
+            if reasoning_content is None:
+                reasoning_content = ''
+
+            # accumulate reasoning content from additional_kwargs into current_thinking
+            current_thinking += reasoning_content
+            reasoning_content_chunk = reasoning_content
+
+        # skip tag parsing only when current_thinking is not an empty string
+        if not in_thinking_block and current_thinking.strip() != '':
+            output_content = content  # output content normally
+            yield {
+                'content': output_content,
+                'reasoning_content': reasoning_content_chunk
+            }
+            get_token_usage(chunk, token_usage)
+            continue  # skip the tag-parsing logic below
+
+        # only when there is no valid reasoning content (and tag parsing is enabled) do we parse tags
+        # if part of a start tag was buffered, prepend it to the current content first
+        if pending_start_tag:
+            content = pending_start_tag + content
+            pending_start_tag = ''
+
+        # check whether a reasoning block starts (handles a start tag that may be truncated)
+        if enable_tag_parsing and not in_thinking_block and start_tag:
+            if start_tag in content:
+                start_idx = content.index(start_tag)
+                # treat it as a real reasoning-block start only when no other text precedes the start tag
+                if start_idx == 0 or content[:start_idx].strip() == '':
+                    # the full tag is present and nothing precedes it
+                    output_content += content[:start_idx]  # emit the content before the start tag
+                    content = content[start_idx + len(start_tag):]  # remove the start tag
+                    in_thinking_block = True
+                else:
+                    # other text precedes the start tag, so do not treat this as a reasoning-block start
+                    output_content += content
+                    content = ''
+            else:
+                # check for a possible partial start tag
+                for i in range(1, len(start_tag)):
+                    if content.endswith(start_tag[:i]):
+                        # buffer the partial tag only when the rest of the content is whitespace
+                        if content[:-i].strip() == '':
+                            pending_start_tag = start_tag[:i]
+                            content = content[:-i]  # remove the possible partial tag
+                        output_content += content
+                        content = ''
+                        break
+
+        # handle content inside the reasoning block
+        if enable_tag_parsing and in_thinking_block and end_tag:
+            if end_tag in content:
+                # end tag found
+                end_idx = content.index(end_tag)
+                current_thinking += content[:end_idx]  # collect the reasoning content
+                reasoning_content_chunk += current_thinking  # add it to this chunk's reasoning content
+                content = content[end_idx + len(end_tag):]  # keep only the content after the end tag
+                current_thinking = ''  # reset the collected reasoning content
+                in_thinking_block = False
+                output_content += content  # emit the content after the end tag
+            else:
+                # keep collecting reasoning content until the end tag is seen
+                current_thinking += content
+                reasoning_content_chunk += content
+                content = ''

+        else:
+            # not in a reasoning block, or tag parsing disabled: output normally
+            output_content += content
+
+        yield {
+            'content': output_content,
+            'reasoning_content': reasoning_content_chunk
+        }
+        get_token_usage(chunk, token_usage)
+
+
 def get_lang_name(lang: str):
     if lang and lang == 'en':
         return '英文'
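The call sites above now iterate plain dicts instead of BaseMessageChunk objects, which is why they read chunk.get('content') and chunk.get('reasoning_content') rather than chunk.content and chunk.additional_kwargs. Below is a minimal sketch of how the new helper behaves, assuming langchain_core's AIMessageChunk as the chunk type, explicit tag arguments in place of the settings defaults, and an import path inferred from the file location (none of these appear in the diff):

from langchain_core.messages import AIMessageChunk

from apps.chat.task.llm import process_stream  # path assumed from backend/apps/chat/task/llm.py

# Simulate a model that inlines its reasoning between <think> tags instead of
# populating additional_kwargs['reasoning_content']; the block spans chunk boundaries.
fake_stream = iter([
    AIMessageChunk(content='<think>check the sch'),
    AIMessageChunk(content='ema first</think>SELECT'),
    AIMessageChunk(content=' * FROM orders'),
])

token_usage = {}
for piece in process_stream(fake_stream, token_usage,
                            enable_tag_parsing=True,
                            start_tag='<think>', end_tag='</think>'):
    # Each yielded item is a dict: {'content': ..., 'reasoning_content': ...}
    print(repr(piece['content']), '|', repr(piece['reasoning_content']))

With this input the tagged text surfaces through 'reasoning_content' while the concatenated 'content' values reduce to 'SELECT * FROM orders'; passing token_usage as a dict lets get_token_usage fill it in place once the provider reports usage_metadata.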

backend/common/core/config.py

Lines changed: 4 additions & 0 deletions
@@ -96,6 +96,10 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
     EMBEDDING_TERMINOLOGY_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT
     EMBEDDING_DATA_TRAINING_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT

+    PARSE_REASONING_BLOCK_ENABLED: bool = True
+    DEFAULT_REASONING_CONTENT_START: str = '<think>'
+    DEFAULT_REASONING_CONTENT_END: str = '</think>'
+
     PG_POOL_SIZE: int = 20
     PG_MAX_OVERFLOW: int = 30
     PG_POOL_RECYCLE: int = 3600
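These defaults target reasoning models that emit their chain of thought inline between <think> and </think> tags rather than through additional_kwargs. Assuming the surrounding Settings class is a pydantic BaseSettings that reads environment variables of the same name (the class definition sits outside this hunk), the behaviour could presumably be tuned without code changes, for example in a hypothetical .env:

# turn inline-tag parsing off entirely
PARSE_REASONING_BLOCK_ENABLED=false

# or keep it on but match a different tag pair
DEFAULT_REASONING_CONTENT_START=<reasoning>
DEFAULT_REASONING_CONTENT_END=</reasoning>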
