1616from sqlalchemy import and_ , cast , or_
1717from sqlalchemy import select
1818from sqlalchemy .dialects .postgresql import JSONB
19+ from sqlalchemy .orm import sessionmaker
1920from sqlbot_xpack .permissions .api .permission import transRecord2DTO
2021from sqlbot_xpack .permissions .models .ds_permission import DsPermission , PermissionDTO
2122from sqlbot_xpack .permissions .models .ds_rules import DsRules
23+ from sqlmodel import create_engine , Session
2224
2325from apps .ai_model .model_factory import LLMConfig , LLMFactory , get_default_config
2426from apps .chat .curd .chat import save_question , save_full_sql_message , save_full_sql_message_and_answer , save_sql , \
3537from apps .system .crud .assistant import AssistantOutDs , AssistantOutDsFactory , get_assistant_ds
3638from apps .system .schemas .system_schema import AssistantOutDsSchema
3739from common .core .config import settings
38- from common .core .deps import CurrentAssistant , SessionDep , CurrentUser
40+ from common .core .deps import CurrentAssistant , CurrentUser
3941from common .utils .utils import SQLBotLogUtil , extract_nested_json
4042
4143warnings .filterwarnings ("ignore" )
@@ -54,7 +56,7 @@ class LLMService:
5456 sql_message : List [Union [BaseMessage , dict [str , Any ]]] = []
5557 chart_message : List [Union [BaseMessage , dict [str , Any ]]] = []
5658 history_records : List [ChatRecord ] = []
57- session : SessionDep
59+ session : Session
5860 current_user : CurrentUser
5961 current_assistant : Optional [CurrentAssistant ] = None
6062 out_ds_instance : Optional [AssistantOutDs ] = None
@@ -63,15 +65,18 @@ class LLMService:
6365 chunk_list : List [str ] = []
6466 future : Future
6567
66- def __init__ (self , session : SessionDep , current_user : CurrentUser , chat_question : ChatQuestion ,
68+ def __init__ (self , current_user : CurrentUser , chat_question : ChatQuestion ,
6769 current_assistant : Optional [CurrentAssistant ] = None ):
6870 self .chunk_list = []
69- self .session = session
71+ engine = create_engine (str (settings .SQLALCHEMY_DATABASE_URI ))
72+ session_maker = sessionmaker (bind = engine )
73+ self .session = session_maker ()
74+
7075 self .current_user = current_user
7176 self .current_assistant = current_assistant
7277 # chat = self.session.query(Chat).filter(Chat.id == chat_question.chat_id).first()
7378 chat_id = chat_question .chat_id
74- chat : Chat = self .session .get (Chat , chat_id )
79+ chat : Chat | None = self .session .get (Chat , chat_id )
7580 if not chat :
7681 raise Exception (f"Chat with id { chat_id } not found" )
7782 ds : CoreDatasource | AssistantOutDsSchema | None = None
@@ -375,10 +380,10 @@ def select_datasource(self):
375380 ]
376381 """ _ds_list = self.session.exec(select(CoreDatasource).options(
377382 load_only(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description))).all() """
378-
383+
379384 ignore_auto_select = _ds_list and len (_ds_list ) == 1
380385 # ignore auto select ds
381-
386+
382387 if not ignore_auto_select :
383388 _ds_list_dict = []
384389 for _ds in _ds_list :
@@ -391,10 +396,10 @@ def select_datasource(self):
391396 history_msg = orjson .loads (self .record .full_select_datasource_message )
392397
393398 self .record = save_full_select_datasource_message_and_answer (session = self .session , record_id = self .record .id ,
394- answer = '' ,
395- full_message = orjson .dumps (history_msg +
396- [{'type' : msg .type ,
397- 'content' : msg .content }
399+ answer = '' ,
400+ full_message = orjson .dumps (history_msg +
401+ [{'type' : msg .type ,
402+ 'content' : msg .content }
398403 for msg
399404 in
400405 datasource_msg ]).decode ())
@@ -461,13 +466,13 @@ def select_datasource(self):
461466
462467 if not ignore_auto_select :
463468 self .record = save_full_select_datasource_message_and_answer (session = self .session , record_id = self .record .id ,
464- answer = orjson .dumps ({'content' : full_text ,
465- 'reasoning_content' : full_thinking_text }).decode (),
466- datasource = _datasource ,
467- engine_type = _engine_type ,
468- full_message = orjson .dumps (history_msg +
469- [{'type' : msg .type ,
470- 'content' : msg .content }
469+ answer = orjson .dumps ({'content' : full_text ,
470+ 'reasoning_content' : full_thinking_text }).decode (),
471+ datasource = _datasource ,
472+ engine_type = _engine_type ,
473+ full_message = orjson .dumps (history_msg +
474+ [{'type' : msg .type ,
475+ 'content' : msg .content }
471476 for msg
472477 in
473478 datasource_msg ]).decode ())
@@ -511,7 +516,6 @@ def generate_sql(self):
511516 [{'type' : msg .type , 'content' : msg .content } for msg in
512517 self .sql_message ]).decode ())
513518
514-
515519 def build_table_filter (self , sql : str , filters : list ):
516520 filter = json .dumps (filters , ensure_ascii = False )
517521 self .chat_question .sql = sql
@@ -561,7 +565,7 @@ def build_table_filter(self, sql: str, filters: list):
561565 # analysis_msg]).decode())
562566 SQLBotLogUtil .info (full_filter_text )
563567 return full_filter_text
564-
568+
565569 def generate_filter (self , sql : str , tables : List ):
566570 table_list = self .session .query (CoreTable ).filter (
567571 and_ (CoreTable .ds_id == self .ds .id , CoreTable .table_name .in_ (tables ))
@@ -586,7 +590,7 @@ def generate_filter(self, sql: str, tables: List):
586590 filters .append ({"table" : table .table_name , "filter" : where_str })
587591
588592 return self .build_table_filter (sql = sql , filters = filters )
589-
593+
590594 def generate_assistant_filter (self , sql , tables : List ):
591595 ds : AssistantOutDsSchema = self .ds
592596 filters = []
@@ -596,7 +600,7 @@ def generate_assistant_filter(self, sql, tables: List):
596600 if not filters :
597601 return None
598602 return self .build_table_filter (sql = sql , filters = filters )
599-
603+
600604 def generate_chart (self ):
601605 # append current question
602606 self .chart_message .append (HumanMessage (self .chat_question .chart_user_question ()))
@@ -810,7 +814,8 @@ def run_task(self, in_chat: bool = True):
810814 SQLBotLogUtil .info (full_sql_text )
811815
812816 # todo row permission
813- if (not self .current_assistant and is_normal_user (self .current_user )) or (self .current_assistant and self .current_assistant .type == 1 ):
817+ if (not self .current_assistant and is_normal_user (self .current_user )) or (
818+ self .current_assistant and self .current_assistant .type == 1 ):
814819 sql_json_str = extract_nested_json (full_sql_text )
815820 data = orjson .loads (sql_json_str )
816821
@@ -831,7 +836,7 @@ def run_task(self, in_chat: bool = True):
831836 sql_result = self .generate_assistant_filter (data .get ('sql' ), data .get ('tables' ))
832837 else :
833838 sql_result = self .generate_filter (data .get ('sql' ), data .get ('tables' )) # maybe no sql and tables
834-
839+
835840 if sql_result :
836841 SQLBotLogUtil .info (sql_result )
837842 sql = self .check_save_sql (res = sql_result )
@@ -1060,6 +1065,7 @@ def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}):
10601065 except Exception :
10611066 pass
10621067
1068+
10631069def get_lang_name (lang : str ):
10641070 if lang and lang == 'en' :
10651071 return '英文'
0 commit comments