11
22
33import json
4+ from typing import Optional
45from fastapi import FastAPI
6+ import requests
57from sqlmodel import Session , select
8+ from apps .chat .task .llm import LLMService
69from apps .datasource .models .datasource import CoreDatasource
710from apps .system .models .system_model import AssistantModel
811from apps .system .schemas .auth import CacheName , CacheNamespace
1114from common .core .db import engine
1215from starlette .middleware .cors import CORSMiddleware
1316from common .core .config import settings
17+ from deps import CurrentUser
1418
1519@cache (namespace = CacheNamespace .EMBEDDED_INFO , cacheName = CacheName .ASSISTANT_INFO , keyExpression = "assistant_id" )
1620async def get_assistant_info (* , session : Session , assistant_id : int ) -> AssistantModel | None :
@@ -20,21 +24,26 @@ async def get_assistant_info(*, session: Session, assistant_id: int) -> Assistan
2024def get_assistant_user (* , id : int ):
2125 return UserInfoDTO (
id = id ,
account = "sqlbot-inner-assistant" ,
oid = 1 ,
name = "sqlbot-inner-assistant" ,
email = "[email protected] " )
2226
23- def get_assistant_ds (* , session : Session , assistant : AssistantModel ):
27+ # def get_assistant_ds(*, session: Session, assistant: AssistantModel):
28+ def get_assistant_ds (llm_service : LLMService ) -> list [dict ]:
29+ assistant : AssistantModel = llm_service .current_assistant
30+ session : Session = llm_service .session
2431 type = assistant .type
2532 if type == 0 :
26- stmt = select (CoreDatasource .id , CoreDatasource .name , CoreDatasource .description )
2733 configuration = assistant .configuration
2834 if configuration :
29- config = json .loads (configuration )
35+ config : dict [any ] = json .loads (configuration )
36+ oid : str = config ['oid' ]
37+ stmt = select (CoreDatasource .id , CoreDatasource .name , CoreDatasource .description ).where (CoreDatasource .oid == oid )
3038 private_list :list [int ] = config ['private_list' ]
31- if not private_list :
39+ if private_list :
3240 stmt .where (~ CoreDatasource .id .in_ (private_list ))
3341 db_ds_list = session .exec (stmt ).all ()
3442 # filter private ds if offline
3543 return db_ds_list
36- out_ds_instance : AssistantOutDs = AssistantOutDsFactory .get_instance (assistant )
37- dslist = out_ds_instance .get_ds_list ()
44+ out_ds_instance : AssistantOutDs = AssistantOutDsFactory .get_instance (assistant , llm_service .assistant_certificate )
45+ llm_service .out_ds_instance = out_ds_instance
46+ dslist = out_ds_instance .get_simple_ds_list ()
3847 # format?
3948 return dslist
4049
@@ -66,16 +75,66 @@ def init_dynamic_cors(app: FastAPI):
6675
6776class AssistantOutDs :
6877 assistant : AssistantModel
69- def get_ds_list (self ):
78+ ds_list : Optional [list [dict ]] = None
79+ certificate : Optional [str ] = None
80+ def __init__ (self , assistant : AssistantModel , certificate : Optional [str ] = None ):
81+ self .assistant = assistant
82+ self .ds_list = None
83+ self .certificate = certificate
84+ self .get_ds_from_api (certificate )
85+
86+ #@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
87+ async def get_ds_from_api (self , certificate : Optional [str ] = None ):
7088 config : dict [any ] = json .loads (self .assistant .configuration )
71- url : str = config ['url' ]
89+ endpoint : str = config ['endpoint' ]
90+ certificateList : list [any ] = json .loads (certificate )
91+ header = {}
92+ cookies = {}
93+ for item in certificateList :
94+ if item ['target' ] == 'head' :
95+ header [item ['key' ]] = item ['value' ]
96+ if item ['target' ] == 'cookie' :
97+ cookies [item ['key' ]] = item ['value' ]
98+
99+ res = requests .get (url = endpoint , headers = header , cookies = cookies , timeout = 10 )
100+ if res .status_code == 200 :
101+ result_json : dict [any ] = json .loads (res .json ())
102+ if result_json .get ('code' ) == 0 :
103+ temp_list = result_json .get ('data' , [])
104+ for idx , item in enumerate (temp_list , start = 1 ):
105+ item ["id" ] = idx
106+ self .ds_list = temp_list
107+ return self .ds_list
108+ else :
109+ raise Exception (f"Failed to get datasource list from { endpoint } , error: { result_json .get ('message' )} " )
110+ else :
111+ raise Exception (f"Failed to get datasource list from { endpoint } , status code: { res .status_code } " )
112+
113+ def get_simple_ds_list (self ):
114+ if self .ds_list :
115+ return [{'id' : ds ['id' ], 'name' : ds ['name' ], 'description' : ds ['comment' ]} for ds in self .ds_list ]
116+ else :
117+ raise Exception ("Datasource list is not found." )
118+
119+ def get_db_schema (self , ds_id : int ):
72120 return None
121+ def get_ds (self , ds_id : int ):
122+ if self .ds_list :
123+ for ds in self .ds_list :
124+ if ds ['id' ] == ds_id :
125+ return ds
126+ else :
127+ raise Exception ("Datasource list is not found." )
128+ raise Exception (f"Datasource with id { ds_id } not found." )
129+ def get_ds_engine (self , ds_id : int ):
130+ ds = self .get_ds (ds_id )
131+ ds_type = ds .get ('type' ) if ds else None
132+ if not ds_type :
133+ raise Exception (f"Datasource with id { ds_id } not found or type is not defined." )
134+ return ds_type
73135
74136class AssistantOutDsFactory :
75- _instance : AssistantOutDs = None
76137 @staticmethod
77- def get_instance (cls , assistant : AssistantModel ) -> AssistantOutDs :
78- if not cls ._instance :
79- cls ._instance = AssistantOutDs (assistant )
80- return cls ._instance
138+ def get_instance (assistant : AssistantModel , certificate : Optional [str ] = None ) -> AssistantOutDs :
139+ return AssistantOutDs (assistant , certificate )
81140
0 commit comments