|
6 | 6 | import warnings |
7 | 7 | from concurrent.futures import ThreadPoolExecutor, Future |
8 | 8 | from datetime import datetime |
9 | | -from typing import Any, List, Optional, Union, Dict |
| 9 | +from typing import Any, List, Optional, Union, Dict, Iterator |
10 | 10 |
|
11 | 11 | import numpy as np |
12 | 12 | import orjson |
@@ -259,22 +259,14 @@ def generate_analysis(self): |
259 | 259 | in analysis_msg]) |
260 | 260 | full_thinking_text = '' |
261 | 261 | full_analysis_text = '' |
262 | | - res = self.llm.stream(analysis_msg) |
263 | 262 | token_usage = {} |
| 263 | + res = process_stream(self.llm.stream(analysis_msg), token_usage) |
264 | 264 | for chunk in res: |
265 | | - SQLBotLogUtil.info(chunk) |
266 | | - reasoning_content_chunk = '' |
267 | | - if 'reasoning_content' in chunk.additional_kwargs: |
268 | | - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
269 | | - # else: |
270 | | - # reasoning_content_chunk = chunk.get('reasoning_content') |
271 | | - if reasoning_content_chunk is None: |
272 | | - reasoning_content_chunk = '' |
273 | | - full_thinking_text += reasoning_content_chunk |
274 | | - |
275 | | - full_analysis_text += chunk.content |
276 | | - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} |
277 | | - get_token_usage(chunk, token_usage) |
| 265 | + if chunk.get('content'): |
| 266 | + full_analysis_text += chunk.get('content') |
| 267 | + if chunk.get('reasoning_content'): |
| 268 | + full_thinking_text += chunk.get('reasoning_content') |
| 269 | + yield chunk |
278 | 270 |
|
279 | 271 | analysis_msg.append(AIMessage(full_analysis_text)) |
280 | 272 |
|
@@ -311,22 +303,14 @@ def generate_predict(self): |
311 | 303 | in predict_msg]) |
312 | 304 | full_thinking_text = '' |
313 | 305 | full_predict_text = '' |
314 | | - res = self.llm.stream(predict_msg) |
315 | 306 | token_usage = {} |
| 307 | + res = process_stream(self.llm.stream(predict_msg), token_usage) |
316 | 308 | for chunk in res: |
317 | | - SQLBotLogUtil.info(chunk) |
318 | | - reasoning_content_chunk = '' |
319 | | - if 'reasoning_content' in chunk.additional_kwargs: |
320 | | - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
321 | | - # else: |
322 | | - # reasoning_content_chunk = chunk.get('reasoning_content') |
323 | | - if reasoning_content_chunk is None: |
324 | | - reasoning_content_chunk = '' |
325 | | - full_thinking_text += reasoning_content_chunk |
326 | | - |
327 | | - full_predict_text += chunk.content |
328 | | - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} |
329 | | - get_token_usage(chunk, token_usage) |
| 309 | + if chunk.get('content'): |
| 310 | + full_predict_text += chunk.get('content') |
| 311 | + if chunk.get('reasoning_content'): |
| 312 | + full_thinking_text += chunk.get('reasoning_content') |
| 313 | + yield chunk |
330 | 314 |
|
331 | 315 | predict_msg.append(AIMessage(full_predict_text)) |
332 | 316 | self.record = save_predict_answer(session=self.session, record_id=self.record.id, |
@@ -370,21 +354,13 @@ def generate_recommend_questions_task(self): |
370 | 354 | full_thinking_text = '' |
371 | 355 | full_guess_text = '' |
372 | 356 | token_usage = {} |
373 | | - res = self.llm.stream(guess_msg) |
| 357 | + res = process_stream(self.llm.stream(guess_msg), token_usage) |
374 | 358 | for chunk in res: |
375 | | - SQLBotLogUtil.info(chunk) |
376 | | - reasoning_content_chunk = '' |
377 | | - if 'reasoning_content' in chunk.additional_kwargs: |
378 | | - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
379 | | - # else: |
380 | | - # reasoning_content_chunk = chunk.get('reasoning_content') |
381 | | - if reasoning_content_chunk is None: |
382 | | - reasoning_content_chunk = '' |
383 | | - full_thinking_text += reasoning_content_chunk |
384 | | - |
385 | | - full_guess_text += chunk.content |
386 | | - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} |
387 | | - get_token_usage(chunk, token_usage) |
| 359 | + if chunk.get('content'): |
| 360 | + full_guess_text += chunk.get('content') |
| 361 | + if chunk.get('reasoning_content'): |
| 362 | + full_thinking_text += chunk.get('reasoning_content') |
| 363 | + yield chunk |
388 | 364 |
|
389 | 365 | guess_msg.append(AIMessage(full_guess_text)) |
390 | 366 |
|
@@ -450,21 +426,13 @@ def select_datasource(self): |
450 | 426 | msg in datasource_msg]) |
451 | 427 |
|
452 | 428 | token_usage = {} |
453 | | - res = self.llm.stream(datasource_msg) |
| 429 | + res = process_stream(self.llm.stream(datasource_msg), token_usage) |
454 | 430 | for chunk in res: |
455 | | - SQLBotLogUtil.info(chunk) |
456 | | - reasoning_content_chunk = '' |
457 | | - if 'reasoning_content' in chunk.additional_kwargs: |
458 | | - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
459 | | - # else: |
460 | | - # reasoning_content_chunk = chunk.get('reasoning_content') |
461 | | - if reasoning_content_chunk is None: |
462 | | - reasoning_content_chunk = '' |
463 | | - full_thinking_text += reasoning_content_chunk |
464 | | - |
465 | | - full_text += chunk.content |
466 | | - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} |
467 | | - get_token_usage(chunk, token_usage) |
| 431 | + if chunk.get('content'): |
| 432 | + full_text += chunk.get('content') |
| 433 | + if chunk.get('reasoning_content'): |
| 434 | + full_thinking_text += chunk.get('reasoning_content') |
| 435 | + yield chunk |
468 | 436 | datasource_msg.append(AIMessage(full_text)) |
469 | 437 |
|
470 | 438 | self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session, |
@@ -560,21 +528,13 @@ def generate_sql(self): |
560 | 528 | full_thinking_text = '' |
561 | 529 | full_sql_text = '' |
562 | 530 | token_usage = {} |
563 | | - res = self.llm.stream(self.sql_message) |
| 531 | + res = process_stream(self.llm.stream(self.sql_message), token_usage) |
564 | 532 | for chunk in res: |
565 | | - SQLBotLogUtil.info(chunk) |
566 | | - reasoning_content_chunk = '' |
567 | | - if 'reasoning_content' in chunk.additional_kwargs: |
568 | | - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
569 | | - # else: |
570 | | - # reasoning_content_chunk = chunk.get('reasoning_content') |
571 | | - if reasoning_content_chunk is None: |
572 | | - reasoning_content_chunk = '' |
573 | | - full_thinking_text += reasoning_content_chunk |
574 | | - |
575 | | - full_sql_text += chunk.content |
576 | | - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} |
577 | | - get_token_usage(chunk, token_usage) |
| 533 | + if chunk.get('content'): |
| 534 | + full_sql_text += chunk.get('content') |
| 535 | + if chunk.get('reasoning_content'): |
| 536 | + full_thinking_text += chunk.get('reasoning_content') |
| 537 | + yield chunk |
578 | 538 |
|
579 | 539 | self.sql_message.append(AIMessage(full_sql_text)) |
580 | 540 |
|
@@ -607,18 +567,14 @@ def generate_with_sub_sql(self, sql, sub_mappings: list): |
607 | 567 |
|
608 | 568 | full_thinking_text = '' |
609 | 569 | full_dynamic_text = '' |
610 | | - res = self.llm.stream(dynamic_sql_msg) |
611 | 570 | token_usage = {} |
| 571 | + res = process_stream(self.llm.stream(dynamic_sql_msg), token_usage) |
612 | 572 | for chunk in res: |
613 | | - SQLBotLogUtil.info(chunk) |
614 | | - reasoning_content_chunk = '' |
615 | | - if 'reasoning_content' in chunk.additional_kwargs: |
616 | | - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
617 | | - if reasoning_content_chunk is None: |
618 | | - reasoning_content_chunk = '' |
619 | | - full_thinking_text += reasoning_content_chunk |
620 | | - full_dynamic_text += chunk.content |
621 | | - get_token_usage(chunk, token_usage) |
| 573 | + if chunk.get('content'): |
| 574 | + full_dynamic_text += chunk.get('content') |
| 575 | + if chunk.get('reasoning_content'): |
| 576 | + full_thinking_text += chunk.get('reasoning_content') |
| 577 | + yield chunk |
622 | 578 |
|
623 | 579 | dynamic_sql_msg.append(AIMessage(full_dynamic_text)) |
624 | 580 |
|
@@ -670,22 +626,13 @@ def build_table_filter(self, sql: str, filters: list): |
670 | 626 | in permission_sql_msg]) |
671 | 627 | full_thinking_text = '' |
672 | 628 | full_filter_text = '' |
673 | | - res = self.llm.stream(permission_sql_msg) |
674 | 629 | token_usage = {} |
| 630 | + res = process_stream(self.llm.stream(permission_sql_msg), token_usage) |
675 | 631 | for chunk in res: |
676 | | - SQLBotLogUtil.info(chunk) |
677 | | - reasoning_content_chunk = '' |
678 | | - if 'reasoning_content' in chunk.additional_kwargs: |
679 | | - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
680 | | - # else: |
681 | | - # reasoning_content_chunk = chunk.get('reasoning_content') |
682 | | - if reasoning_content_chunk is None: |
683 | | - reasoning_content_chunk = '' |
684 | | - full_thinking_text += reasoning_content_chunk |
685 | | - |
686 | | - full_filter_text += chunk.content |
687 | | - # yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} |
688 | | - get_token_usage(chunk, token_usage) |
| 632 | + if chunk.get('content'): |
| 633 | + full_filter_text += chunk.get('content') |
| 634 | + if chunk.get('reasoning_content'): |
| 635 | + full_thinking_text += chunk.get('reasoning_content') |
689 | 636 |
|
690 | 637 | permission_sql_msg.append(AIMessage(full_filter_text)) |
691 | 638 |
|
@@ -735,21 +682,13 @@ def generate_chart(self, chart_type: Optional[str] = ''): |
735 | 682 | full_thinking_text = '' |
736 | 683 | full_chart_text = '' |
737 | 684 | token_usage = {} |
738 | | - res = self.llm.stream(self.chart_message) |
| 685 | + res = process_stream(self.llm.stream(self.chart_message), token_usage) |
739 | 686 | for chunk in res: |
740 | | - SQLBotLogUtil.info(chunk) |
741 | | - reasoning_content_chunk = '' |
742 | | - if 'reasoning_content' in chunk.additional_kwargs: |
743 | | - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
744 | | - # else: |
745 | | - # reasoning_content_chunk = chunk.get('reasoning_content') |
746 | | - if reasoning_content_chunk is None: |
747 | | - reasoning_content_chunk = '' |
748 | | - full_thinking_text += reasoning_content_chunk |
749 | | - |
750 | | - full_chart_text += chunk.content |
751 | | - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} |
752 | | - get_token_usage(chunk, token_usage) |
| 687 | + if chunk.get('content'): |
| 688 | + full_chart_text += chunk.get('content') |
| 689 | + if chunk.get('reasoning_content'): |
| 690 | + full_thinking_text += chunk.get('reasoning_content') |
| 691 | + yield chunk |
753 | 692 |
|
754 | 693 | self.chart_message.append(AIMessage(full_chart_text)) |
755 | 694 |
|
@@ -1053,7 +992,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True, |
1053 | 992 | else: |
1054 | 993 | sql = self.check_save_sql(res=full_sql_text) |
1055 | 994 |
|
1056 | | - SQLBotLogUtil.info(sql) |
| 995 | + SQLBotLogUtil.info('sql: ' + sql) |
1057 | 996 |
|
1058 | 997 | if not stream: |
1059 | 998 | json_result['sql'] = sql |
@@ -1372,16 +1311,116 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict): |
1372 | 1311 | return request_path |
1373 | 1312 |
|
1374 | 1313 |
|
def get_token_usage(chunk: "BaseMessageChunk", token_usage: dict = None):
    """Copy token-usage counters from a streamed chunk into *token_usage*.

    Mutates the caller-supplied ``token_usage`` dict in place so the caller
    can read the final totals after the stream is consumed.  Chunks without
    usage metadata (most intermediate chunks) are ignored silently.

    :param chunk: a streamed message chunk; only ``usage_metadata`` is read.
    :param token_usage: dict filled with ``input_tokens`` / ``output_tokens``
        / ``total_tokens``; when ``None`` a throwaway dict is used, i.e. the
        extracted values are discarded.
    """
    if token_usage is None:
        token_usage = {}
    # Tolerate chunk types that have no usage_metadata attribute explicitly,
    # instead of relying on a broad try/except to swallow the AttributeError.
    usage = getattr(chunk, 'usage_metadata', None)
    if not usage:
        return
    try:
        token_usage['input_tokens'] = usage.get('input_tokens')
        token_usage['output_tokens'] = usage.get('output_tokens')
        token_usage['total_tokens'] = usage.get('total_tokens')
    except Exception:
        # Best-effort accounting: never let usage extraction break the stream.
        pass
1383 | 1324 |
|
1384 | 1325 |
|
def process_stream(res: Iterator[BaseMessageChunk],
                   token_usage: Optional[Dict[str, Any]] = None,
                   enable_tag_parsing: bool = settings.PARSE_REASONING_BLOCK_ENABLED,
                   start_tag: str = settings.DEFAULT_REASONING_CONTENT_START,
                   end_tag: str = settings.DEFAULT_REASONING_CONTENT_END
                   ):
    """Normalize an LLM chunk stream into ``{'content', 'reasoning_content'}`` dicts.

    Reasoning text is taken from ``chunk.additional_kwargs['reasoning_content']``
    when the provider supplies non-blank text there; otherwise, when
    *enable_tag_parsing* is on, inline ``start_tag``/``end_tag`` blocks — which
    may be split across chunk boundaries — are extracted from ``chunk.content``
    and emitted as reasoning instead.  Token-usage counters are accumulated
    into *token_usage* in place via ``get_token_usage``.

    :param res: iterator of streamed message chunks, e.g. from ``llm.stream(...)``.
    :param token_usage: dict mutated with input/output/total token counts;
        a fresh dict is used when ``None``.
    :param enable_tag_parsing: whether to scan ``content`` for reasoning tags.
    :param start_tag: opening marker of an inline reasoning block.
    :param end_tag: closing marker of an inline reasoning block.
    """
    if token_usage is None:
        token_usage = {}
    in_thinking_block = False  # True while inside an inline reasoning block
    current_thinking = ''  # reasoning text collected so far
    pending_start_tag = ''  # buffers a start tag possibly split across chunks

    for chunk in res:
        SQLBotLogUtil.info(chunk)
        reasoning_content_chunk = ''
        content = chunk.content
        output_content = ''  # the content actually emitted for this chunk

        # Prefer provider-supplied reasoning from additional_kwargs.
        if 'reasoning_content' in chunk.additional_kwargs:
            reasoning_content = chunk.additional_kwargs.get('reasoning_content', '')
            if reasoning_content is None:
                reasoning_content = ''

            # Accumulate the provider-supplied reasoning into current_thinking.
            current_thinking += reasoning_content
            reasoning_content_chunk = reasoning_content

            # Skip tag parsing only once some non-blank reasoning has arrived.
            if not in_thinking_block and current_thinking.strip() != '':
                output_content = content  # pass content through unchanged
                yield {
                    'content': output_content,
                    'reasoning_content': reasoning_content_chunk
                }
                get_token_usage(chunk, token_usage)
                continue  # skip the tag-parsing logic below

        # No usable provider reasoning: fall back to tag parsing (if enabled).
        # Re-attach any start-tag fragment buffered from the previous chunk.
        if pending_start_tag:
            content = pending_start_tag + content
            pending_start_tag = ''

        # Detect the start of a reasoning block (handling split start tags).
        if enable_tag_parsing and not in_thinking_block and start_tag:
            if start_tag in content:
                start_idx = content.index(start_tag)
                # Treat it as a real block start only when nothing but
                # whitespace precedes the tag.
                if start_idx == 0 or content[:start_idx].strip() == '':
                    output_content += content[:start_idx]  # emit text before the tag
                    content = content[start_idx + len(start_tag):]  # drop the start tag
                    in_thinking_block = True
                else:
                    # Text precedes the tag, so this is not a block start.
                    output_content += content
                    content = ''
            else:
                # The chunk may end with a truncated start tag.
                for i in range(1, len(start_tag)):
                    if content.endswith(start_tag[:i]):
                        # Buffer the fragment only when the rest is whitespace.
                        if content[:-i].strip() == '':
                            pending_start_tag = start_tag[:i]
                            content = content[:-i]  # strip the possible fragment
                        output_content += content
                        content = ''
                        break

        # Inside a reasoning block: collect until the end tag appears.
        if enable_tag_parsing and in_thinking_block and end_tag:
            if end_tag in content:
                # End tag found: flush collected reasoning, resume content.
                end_idx = content.index(end_tag)
                current_thinking += content[:end_idx]  # reasoning before the tag
                reasoning_content_chunk += current_thinking  # flush into this chunk
                content = content[end_idx + len(end_tag):]  # keep text after the tag
                current_thinking = ''  # reset for a possible next block
                in_thinking_block = False
                output_content += content  # emit text after the end tag
            else:
                # Still inside the block: everything is reasoning.
                current_thinking += content
                reasoning_content_chunk += content
                content = ''

        else:
            # Not in a reasoning block (or parsing disabled): emit as-is.
            output_content += content

        yield {
            'content': output_content,
            'reasoning_content': reasoning_content_chunk
        }
        get_token_usage(chunk, token_usage)
| 1422 | + |
| 1423 | + |
1385 | 1424 | def get_lang_name(lang: str): |
1386 | 1425 | if lang and lang == 'en': |
1387 | 1426 | return '英文' |
|
0 commit comments