99if platform .system () != "Darwin" :
1010 import dmPython
1111import pymysql
12+ import redshift_connector
1213from sqlalchemy import create_engine , text , Engine
1314from sqlalchemy .orm import sessionmaker
1415
@@ -139,6 +140,19 @@ def check_connection(trans: Trans, ds: CoreDatasource, is_raise: bool = False):
139140 if is_raise :
140141 raise HTTPException (status_code = 500 , detail = trans ('i18n_ds_invalid' ) + f': { e .args } ' )
141142 return False
143+ elif ds .type == 'redshift' :
144+ with redshift_connector .connect (host = conf .host , port = conf .port , database = conf .database , user = conf .username ,
145+ password = conf .password ,
146+ timeout = 10 ) as conn , conn .cursor () as cursor :
147+ try :
148+ cursor .execute ('select 1' )
149+ SQLBotLogUtil .info ("success" )
150+ return True
151+ except Exception as e :
152+ SQLBotLogUtil .error (f"Datasource { ds .id } connection failed: { e } " )
153+ if is_raise :
154+ raise HTTPException (status_code = 500 , detail = trans ('i18n_ds_invalid' ) + f': { e .args } ' )
155+ return False
142156
143157
144158def get_version (ds : CoreDatasource ):
@@ -165,6 +179,8 @@ def get_version(ds: CoreDatasource):
165179 cursor .execute (sql )
166180 res = cursor .fetchall ()
167181 return res [0 ][0 ]
182+ elif ds .type == 'redshift' :
183+ return ''
168184 except Exception as e :
169185 print (e )
170186 return ''
@@ -194,6 +210,14 @@ def get_schema(ds: CoreDatasource):
194210 res = cursor .fetchall ()
195211 res_list = [item [0 ] for item in res ]
196212 return res_list
213+ elif ds .type == 'redshift' :
214+ with redshift_connector .connect (host = conf .host , port = conf .port , database = conf .database , user = conf .username ,
215+ password = conf .password ,
216+ timeout = conf .timeout ) as conn , conn .cursor () as cursor :
217+ cursor .execute (f"""SELECT nspname FROM pg_namespace""" )
218+ res = cursor .fetchall ()
219+ res_list = [item [0 ] for item in res ]
220+ return res_list
197221
198222
199223def get_tables (ds : CoreDatasource ):
@@ -222,6 +246,14 @@ def get_tables(ds: CoreDatasource):
222246 res = cursor .fetchall ()
223247 res_list = [TableSchema (* item ) for item in res ]
224248 return res_list
249+ elif ds .type == 'redshift' :
250+ with redshift_connector .connect (host = conf .host , port = conf .port , database = conf .database , user = conf .username ,
251+ password = conf .password ,
252+ timeout = conf .timeout ) as conn , conn .cursor () as cursor :
253+ cursor .execute (sql )
254+ res = cursor .fetchall ()
255+ res_list = [TableSchema (* item ) for item in res ]
256+ return res_list
225257
226258
227259def get_fields (ds : CoreDatasource , table_name : str = None ):
@@ -250,6 +282,14 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
250282 res = cursor .fetchall ()
251283 res_list = [ColumnSchema (* item ) for item in res ]
252284 return res_list
285+ elif ds .type == 'redshift' :
286+ with redshift_connector .connect (host = conf .host , port = conf .port , database = conf .database , user = conf .username ,
287+ password = conf .password ,
288+ timeout = conf .timeout ) as conn , conn .cursor () as cursor :
289+ cursor .execute (sql )
290+ res = cursor .fetchall ()
291+ res_list = [ColumnSchema (* item ) for item in res ]
292+ return res_list
253293
254294
255295def exec_sql (ds : CoreDatasource | AssistantOutDsSchema , sql : str , origin_column = False ):
@@ -311,3 +351,22 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
311351 "sql" : bytes .decode (base64 .b64encode (bytes (sql , 'utf-8' )))}
312352 except Exception as ex :
313353 raise ex
354+ elif ds .type == 'redshift' :
355+ with redshift_connector .connect (host = conf .host , port = conf .port , database = conf .database , user = conf .username ,
356+ password = conf .password ,
357+ timeout = conf .timeout ) as conn , conn .cursor () as cursor :
358+ try :
359+ cursor .execute (sql )
360+ res = cursor .fetchall ()
361+ columns = [field [0 ] for field in cursor .description ] if origin_column else [field [0 ].lower () for
362+ field in
363+ cursor .description ]
364+ result_list = [
365+ {str (columns [i ]): float (value ) if isinstance (value , Decimal ) else value for i , value in
366+ enumerate (tuple_item )}
367+ for tuple_item in res
368+ ]
369+ return {"fields" : columns , "data" : result_list ,
370+ "sql" : bytes .decode (base64 .b64encode (bytes (sql , 'utf-8' )))}
371+ except Exception as ex :
372+ raise ex
0 commit comments