|
1 | 1 | import logging |
| 2 | +import traceback |
2 | 3 | import warnings |
3 | 4 | from typing import Any, List, Union, Dict |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import orjson |
7 | 8 | import pandas as pd |
| 9 | +import requests |
8 | 10 | from langchain_community.utilities import SQLDatabase |
9 | 11 | from langchain_core.language_models import BaseLLM |
10 | 12 | from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, AIMessageChunk |
|
22 | 24 | from apps.db.db import exec_sql |
23 | 25 | from apps.system.crud.user import get_user_info |
24 | 26 | from apps.system.models.system_model import AiModelDetail |
| 27 | +from common.core.config import settings |
25 | 28 | from common.core.deps import SessionDep, CurrentUser |
26 | 29 |
|
27 | 30 | warnings.filterwarnings("ignore") |
@@ -68,8 +71,9 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question |
68 | 71 |
|
69 | 72 | history_records: List[ChatRecord] = list( |
70 | 73 | map(lambda x: ChatRecord(**x.model_dump()), filter(lambda r: True if r.first_chat != True else False, |
71 | | - list_records(session=self.session, current_user=current_user, |
72 | | - chart_id=chat_question.chat_id)))) |
| 74 | + list_records(session=self.session, |
| 75 | + current_user=current_user, |
| 76 | + chart_id=chat_question.chat_id)))) |
73 | 77 | # get schema |
74 | 78 | if ds: |
75 | 79 | chat_question.db_schema = get_table_schema(session=self.session, ds=ds) |
@@ -639,12 +643,49 @@ def run_task(llm_service: LLMService, session: SessionDep, in_chat: bool = True) |
639 | 643 | else: |
640 | 644 | # todo generate picture |
641 | 645 | if chart['type'] != 'table': |
642 | | - yield '# todo generate chart picture' |
643 | | - |
644 | | - yield f'![{chart["type"]}](https://sqlbot.fit2cloud.cn/images/111.png)' |
| 646 | + yield '# generated chart picture' |
| 647 | + image_url = request_picture(llm_service.record.chat_id, llm_service.record.id, chart, result) |
| 648 | + print(image_url) |
| 649 | + yield f'![{chart["type"]}]({image_url})' |
645 | 650 | except Exception as e: |
| 651 | + traceback.print_exc() |
646 | 652 | llm_service.save_error(message=str(e)) |
647 | 653 | if in_chat: |
648 | 654 | yield orjson.dumps({'content': str(e), 'type': 'error'}).decode() + '\n\n' |
649 | 655 | else: |
650 | 656 | yield f'> ❌ **ERROR**\n\n> \n\n> {str(e)}。' |
| 657 | + |
| 658 | + |
| 659 | +def request_picture(chat_id: int, record_id: int, chart: dict, data: dict): |
| 660 | + file_name = f'c_{chat_id}_r_{record_id}' |
| 661 | + |
| 662 | + columns = chart.get('columns') if chart.get('columns') else [] |
| 663 | + x = None |
| 664 | + y = None |
| 665 | + series = None |
| 666 | + if chart.get('axis'): |
| 667 | + x = chart.get('axis').get('x') |
| 668 | + y = chart.get('axis').get('y') |
| 669 | + series = chart.get('axis').get('series') |
| 670 | + |
| 671 | + axis = [] |
| 672 | + for v in columns: |
| 673 | + axis.append({'name': v.get('name'), 'value': v.get('value')}) |
| 674 | + if x: |
| 675 | + axis.append({'name': x.get('name'), 'value': x.get('value'), 'type': 'x'}) |
| 676 | + if y: |
| 677 | + axis.append({'name': y.get('name'), 'value': y.get('value'), 'type': 'y'}) |
| 678 | + if series: |
| 679 | + axis.append({'name': series.get('name'), 'value': series.get('value'), 'type': 'series'}) |
| 680 | + |
| 681 | + request_obj = { |
| 682 | + "path": (settings.MCP_IMAGE_PATH if settings.MCP_IMAGE_PATH[-1] == '/' else ( |
| 683 | + settings.MCP_IMAGE_PATH + '/')) + file_name, |
| 684 | + "type": chart['type'], |
| 685 | + "data": orjson.dumps(data.get('data') if data.get('data') else []).decode(), |
| 686 | + "axis": orjson.dumps(axis).decode(), |
| 687 | + } |
| 688 | + |
| 689 | + requests.post(url=settings.MCP_IMAGE_HOST, json=request_obj) |
| 690 | + |
| 691 | + return f'{(settings.SERVER_IMAGE_HOST if settings.SERVER_IMAGE_HOST[-1] == "/" else (settings.SERVER_IMAGE_HOST + "/"))}{file_name}.png' |
0 commit comments