1-
2-
31import json
2+ import urllib
43from typing import Optional
5- from fastapi import FastAPI
4+
65import requests
6+ from fastapi import FastAPI
77from sqlalchemy import Engine , create_engine
88from sqlmodel import Session , select
9- from apps . datasource . models . datasource import CoreDatasource , DatasourceConf
9+ from starlette . middleware . cors import CORSMiddleware
1010
11+ from apps .datasource .models .datasource import CoreDatasource , DatasourceConf
1112from apps .system .models .system_model import AssistantModel
1213from apps .system .schemas .auth import CacheName , CacheNamespace
1314from apps .system .schemas .system_schema import AssistantHeader , AssistantOutDsSchema , UserInfoDTO
14- from common .core .sqlbot_cache import cache
15- from common .core .db import engine
16- from starlette .middleware .cors import CORSMiddleware
1715from common .core .config import settings
16+ from common .core .db import engine
17+ from common .core .sqlbot_cache import cache
1818from common .utils .utils import string_to_numeric_hash
1919
20+
2021@cache (namespace = CacheNamespace .EMBEDDED_INFO , cacheName = CacheName .ASSISTANT_INFO , keyExpression = "assistant_id" )
2122async def get_assistant_info (* , session : Session , assistant_id : int ) -> AssistantModel | None :
2223 db_model = session .get (AssistantModel , assistant_id )
2324 return db_model
2425
26+
2527def get_assistant_user (* , id : int ):
26- return UserInfoDTO (
id = id ,
account = "sqlbot-inner-assistant" ,
oid = 1 ,
name = "sqlbot-inner-assistant" ,
email = "[email protected] " )
28+ return UserInfoDTO (id = id , account = "sqlbot-inner-assistant" , oid = 1 , name = "sqlbot-inner-assistant" ,
29+ 30+
2731
2832def get_assistant_ds (llm_service ) -> list [dict ]:
2933 assistant : AssistantHeader = llm_service .current_assistant
@@ -34,13 +38,14 @@ def get_assistant_ds(llm_service) -> list[dict]:
3438 if configuration :
3539 config : dict [any ] = json .loads (configuration )
3640 oid : int = int (config ['oid' ])
37- stmt = select (CoreDatasource .id , CoreDatasource .name , CoreDatasource .description ).where (CoreDatasource .oid == oid )
41+ stmt = select (CoreDatasource .id , CoreDatasource .name , CoreDatasource .description ).where (
42+ CoreDatasource .oid == oid )
3843 if not assistant .online :
39- private_list :list [int ] = config .get ('private_list' ) or None
44+ private_list : list [int ] = config .get ('private_list' ) or None
4045 if private_list :
4146 stmt = stmt .where (~ CoreDatasource .id .in_ (private_list ))
4247 db_ds_list = session .exec (stmt )
43-
48+
4449 result_list = [
4550 {
4651 "id" : ds .id ,
@@ -49,7 +54,7 @@ def get_assistant_ds(llm_service) -> list[dict]:
4954 }
5055 for ds in db_ds_list
5156 ]
52-
57+
5358 # filter private ds if offline
5459 return result_list
5560 out_ds_instance : AssistantOutDs = AssistantOutDsFactory .get_instance (assistant )
@@ -58,8 +63,9 @@ def get_assistant_ds(llm_service) -> list[dict]:
5863 # format?
5964 return dslist
6065
66+
6167def init_dynamic_cors (app : FastAPI ):
62- try :
68+ try :
6369 with Session (engine ) as session :
6470 list_result = session .exec (select (AssistantModel ).order_by (AssistantModel .create_time )).all ()
6571 seen = set ()
@@ -81,20 +87,20 @@ def init_dynamic_cors(app: FastAPI):
8187 cors_middleware .kwargs ['allow_origins' ] = updated_origins
8288 except Exception as e :
8389 return False , e
84-
85-
90+
8691
8792class AssistantOutDs :
8893 assistant : AssistantHeader
8994 ds_list : Optional [list [AssistantOutDsSchema ]] = None
9095 certificate : Optional [str ] = None
96+
9197 def __init__ (self , assistant : AssistantHeader ):
9298 self .assistant = assistant
9399 self .ds_list = None
94100 self .certificate = assistant .certificate
95101 self .get_ds_from_api ()
96-
97- #@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
102+
103+ # @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
98104 def get_ds_from_api (self ):
99105 config : dict [any ] = json .loads (self .assistant .configuration )
100106 endpoint : str = config ['endpoint' ]
@@ -118,23 +124,23 @@ def get_ds_from_api(self):
118124 self .convert2schema (item )
119125 for item in temp_list
120126 ]
121-
127+
122128 return self .ds_list
123129 else :
124130 raise Exception (f"Failed to get datasource list from { endpoint } , error: { result_json .get ('message' )} " )
125131 else :
126132 raise Exception (f"Failed to get datasource list from { endpoint } , status code: { res .status_code } " )
127-
133+
128134 def get_simple_ds_list (self ):
129135 if self .ds_list :
130136 return [{'id' : ds .id , 'name' : ds .name , 'description' : ds .comment } for ds in self .ds_list ]
131137 else :
132138 raise Exception ("Datasource list is not found." )
133-
139+
134140 def get_db_schema (self , ds_id : int ) -> str :
135141 ds = self .get_ds (ds_id )
136142 schema_str = ""
137- #db_name = ds.db_schema
143+ # db_name = ds.db_schema
138144 db_name = ds .db_schema if ds .db_schema is not None and ds .db_schema != "" else ds .dataBase
139145 schema_str += f"【DB_ID】 { db_name } \n 【Schema】\n "
140146 for table in ds .tables :
@@ -144,7 +150,7 @@ def get_db_schema(self, ds_id: int) -> str:
144150 schema_str += '\n [\n '
145151 else :
146152 schema_str += f", { table_comment } \n [\n "
147-
153+
148154 field_list = []
149155 for field in table .fields :
150156 field_comment = field .comment
@@ -155,7 +161,7 @@ def get_db_schema(self, ds_id: int) -> str:
155161 schema_str += ",\n " .join (field_list )
156162 schema_str += '\n ]\n '
157163 return schema_str
158-
164+
159165 def get_ds (self , ds_id : int ):
160166 if self .ds_list :
161167 for ds in self .ds_list :
@@ -175,20 +181,22 @@ def convert2schema(self, ds_dict: dict) -> AssistantOutDsSchema:
175181 db_schema = ds_dict .get ('schema' , ds_dict .get ('db_schema' , '' ))
176182 ds_dict .pop ("schema" , None )
177183 return AssistantOutDsSchema (** {** ds_dict , "id" : id , "db_schema" : db_schema })
178-
184+
185+
179186class AssistantOutDsFactory :
180187 @staticmethod
181188 def get_instance (assistant : AssistantHeader ) -> AssistantOutDs :
182189 return AssistantOutDs (assistant )
183190
191+
184192def get_ds_engine (ds : AssistantOutDsSchema ) -> Engine :
185193 timeout : int = 30
186194 connect_args = {"connect_timeout" : timeout }
187195 conf = DatasourceConf (
188- host = ds .host ,
189- port = ds .port ,
196+ host = ds .host ,
197+ port = ds .port ,
190198 username = ds .user ,
191- password = ds .password ,
199+ password = ds .password ,
192200 database = ds .dataBase ,
193201 driver = '' ,
194202 extraJdbc = ds .extraParams ,
@@ -197,8 +205,26 @@ def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
197205 conf .extraJdbc = ''
198206 from apps .db .db import get_uri_from_config
199207 uri = get_uri_from_config (ds .type , conf )
208+ # if ds.type == "pg" and ds.db_schema:
209+ # connect_args.update({"options": f"-c search_path={ds.db_schema}"})
210+ # engine = create_engine(uri, connect_args=connect_args, pool_timeout=timeout, pool_size=20, max_overflow=10)
211+
200212 if ds .type == "pg" and ds .db_schema :
201- connect_args .update ({"options" : f"-c search_path={ ds .db_schema } " })
202- engine = create_engine (uri , connect_args = connect_args , pool_timeout = timeout , pool_size = 20 , max_overflow = 10 )
213+ engine = create_engine (uri ,
214+ connect_args = {"options" : f"-c search_path={ urllib .parse .quote (ds .db_schema )} " ,
215+ "connect_timeout" : timeout },
216+ pool_timeout = timeout , pool_size = 20 , max_overflow = 10 )
217+ elif ds .type == 'sqlServer' :
218+ engine = create_engine (uri , pool_timeout = timeout ,
219+ pool_size = 20 ,
220+ max_overflow = 10 )
221+ elif ds .type == 'oracle' :
222+ engine = create_engine (uri ,
223+ pool_timeout = timeout ,
224+ pool_size = 20 ,
225+ max_overflow = 10 )
226+ else :
227+ engine = create_engine (uri , connect_args = {"connect_timeout" : timeout }, pool_timeout = timeout ,
228+ pool_size = 20 ,
229+ max_overflow = 10 )
203230 return engine
204-
0 commit comments