11import logging
22import traceback
33import warnings
4- from typing import Any , List , Union , Dict
4+ from typing import Any , List , Optional , Union , Dict
55
66import numpy as np
77import orjson
2323from apps .datasource .crud .datasource import get_table_schema
2424from apps .datasource .models .datasource import CoreDatasource
2525from apps .db .db import exec_sql
26+ from apps .system .crud .assistant import get_assistant_ds
2627from common .core .config import settings
27- from common .core .deps import SessionDep , CurrentUser
28+ from common .core .deps import CurrentAssistant , SessionDep , CurrentUser
2829from common .utils .utils import extract_nested_json
2930
3031warnings .filterwarnings ("ignore" )
@@ -42,13 +43,14 @@ class LLMService:
4243 chart_message : List [Union [BaseMessage , dict [str , Any ]]] = []
4344 history_records : List [ChatRecord ] = []
4445 session : SessionDep
45- _current_user : CurrentUser
46+ current_user : CurrentUser
47+ current_assistant : Optional [CurrentAssistant ] = None
4648
47- def __init__ (self , session : SessionDep , current_user : CurrentUser , chat_question : ChatQuestion ):
49+ def __init__ (self , session : SessionDep , current_user : CurrentUser , chat_question : ChatQuestion , current_assistant : Optional [ CurrentAssistant ] = None ):
4850
4951 self .session = session
5052 self .current_user = current_user
51-
53+ self . current_assistant = current_assistant
5254 #chat = self.session.query(Chat).filter(Chat.id == chat_question.chat_id).first()
5355 chat_id = chat_question .chat_id
5456 chat : Chat = self .session .get (Chat , chat_id )
@@ -332,8 +334,11 @@ def generate_recommend_questions_task(self):
332334 def select_datasource (self ):
333335 datasource_msg : List [Union [BaseMessage , dict [str , Any ]]] = []
334336 datasource_msg .append (SystemMessage (self .chat_question .datasource_sys_question ()))
335- _ds_list = self .session .exec (select (CoreDatasource ).options (
336- load_only (CoreDatasource .id , CoreDatasource .name , CoreDatasource .description ))).all ()
337+ if self .current_assistant :
338+ _ds_list = get_assistant_ds (session = self .session , assistant = self .current_assistant )
339+ else :
340+ _ds_list = self .session .exec (select (CoreDatasource ).options (
341+ load_only (CoreDatasource .id , CoreDatasource .name , CoreDatasource .description ))).all ()
337342 _ds_list_dict = []
338343 for _ds in _ds_list :
339344 _ds_list_dict .append ({'id' : _ds [0 ].id , 'name' : _ds [0 ].name , 'description' : _ds [0 ].description })
0 commit comments