33import traceback
44import warnings
55from concurrent .futures import ThreadPoolExecutor , Future
6+ from datetime import datetime
67from typing import Any , List , Optional , Union , Dict
78
89import numpy as np
@@ -70,7 +71,8 @@ class LLMService:
7071 future : Future
7172
7273 def __init__ (self , current_user : CurrentUser , chat_question : ChatQuestion ,
73- current_assistant : Optional [CurrentAssistant ] = None , no_reasoning : bool = False , config : LLMConfig = None ):
74+ current_assistant : Optional [CurrentAssistant ] = None , no_reasoning : bool = False ,
75+ config : LLMConfig = None ):
7476 self .chunk_list = []
7577 engine = create_engine (str (settings .SQLALCHEMY_DATABASE_URI ))
7678 session_maker = sessionmaker (bind = engine )
@@ -126,7 +128,7 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
126128 self .llm = llm_instance .llm
127129
128130 self .init_messages ()
129-
131+
130132 @classmethod
131133 async def create (cls , * args , ** kwargs ):
132134 config : LLMConfig = await get_default_config ()
@@ -503,7 +505,8 @@ def select_datasource(self):
503505
504506 def generate_sql (self ):
505507 # append current question
506- self .sql_message .append (HumanMessage (self .chat_question .sql_user_question ()))
508+ self .sql_message .append (HumanMessage (
509+ self .chat_question .sql_user_question (current_time = datetime .now ().strftime ('%Y-%m-%d %H:%M:%S' ))))
507510
508511 self .current_logs [OperationEnum .GENERATE_SQL ] = start_log (session = self .session ,
509512 ai_modal_id = self .chat_question .ai_modal_id ,
@@ -670,9 +673,9 @@ def generate_assistant_filter(self, sql, tables: List):
670673 return None
671674 return self .build_table_filter (sql = sql , filters = filters )
672675
673- def generate_chart (self ):
676+ def generate_chart (self , chart_type : Optional [ str ] = '' ):
674677 # append current question
675- self .chart_message .append (HumanMessage (self .chat_question .chart_user_question ()))
678+ self .chart_message .append (HumanMessage (self .chat_question .chart_user_question (chart_type )))
676679
677680 self .current_logs [OperationEnum .GENERATE_CHART ] = start_log (session = self .session ,
678681 ai_modal_id = self .chat_question .ai_modal_id ,
@@ -714,7 +717,8 @@ def generate_chart(self):
714717 reasoning_content = full_thinking_text ,
715718 token_usage = token_usage )
716719
717- def check_sql (self , res : str ) -> tuple [any ]:
720+ @staticmethod
721+ def check_sql (res : str ) -> tuple [str , Optional [list ]]:
718722 json_str = extract_nested_json (res )
719723 if json_str is None :
720724 raise SingleMessageError (orjson .dumps ({'message' : 'Cannot parse sql from answer' ,
@@ -739,6 +743,26 @@ def check_sql(self, res: str) -> tuple[any]:
739743 raise SingleMessageError ("SQL query is empty" )
740744 return sql , data .get ('tables' )
741745
746+ @staticmethod
747+ def get_chart_type_from_sql_answer (res : str ) -> Optional [str ]:
748+ json_str = extract_nested_json (res )
749+ if json_str is None :
750+ return None
751+
752+ chart_type : Optional [str ]
753+ data : dict
754+ try :
755+ data = orjson .loads (json_str )
756+
757+ if data ['success' ]:
758+ chart_type = data ['chart-type' ]
759+ else :
760+ return None
761+ except Exception :
762+ return None
763+
764+ return chart_type
765+
742766 def check_save_sql (self , res : str ) -> str :
743767 sql , * _ = self .check_sql (res = res )
744768 save_sql (session = self .session , sql = sql , record_id = self .record .id )
@@ -921,6 +945,9 @@ def run_task(self, in_chat: bool = True):
921945
922946 # filter sql
923947 SQLBotLogUtil .info (full_sql_text )
948+
949+ chart_type = self .get_chart_type_from_sql_answer (full_sql_text )
950+
924951 use_dynamic_ds : bool = self .current_assistant and self .current_assistant .type in dynamic_ds_types
925952
926953 # todo row permission
@@ -962,7 +989,7 @@ def run_task(self, in_chat: bool = True):
962989 yield 'data:' + orjson .dumps ({'content' : 'execute-success' , 'type' : 'sql-data' }).decode () + '\n \n '
963990
964991 # generate chart
965- chart_res = self .generate_chart ()
992+ chart_res = self .generate_chart (chart_type )
966993 full_chart_text = ''
967994 for chunk in chart_res :
968995 full_chart_text += chunk .get ('content' )
0 commit comments