@@ -36,7 +36,11 @@ def _get_connection_attributes_from_catalog(
3636 details : Dict [str , Any ] = get_connection (name = connection , catalog_id = catalog_id , boto3_session = boto3_session )[
3737 "ConnectionProperties"
3838 ]
39- port , database = details ["JDBC_CONNECTION_URL" ].split (":" )[3 ].split ("/" )
39+ if ";databaseName=" in details ["JDBC_CONNECTION_URL" ]:
40+ database_sep = ";databaseName="
41+ else :
42+ database_sep = "/"
43+ port , database = details ["JDBC_CONNECTION_URL" ].split (":" )[3 ].split (database_sep )
4044 return ConnectionAttributes (
4145 kind = details ["JDBC_CONNECTION_URL" ].split (":" )[1 ].lower (),
4246 user = details ["USERNAME" ],
@@ -136,19 +140,48 @@ def _records2df(
136140 return df
137141
138142
139- def _iterate_cursor (
140- cursor : Any ,
143+ def _get_cols_names (cursor_description : Any ) -> List [str ]:
144+ cols_names = [col [0 ].decode ("utf-8" ) if isinstance (col [0 ], bytes ) else col [0 ] for col in cursor_description ]
145+ _logger .debug ("cols_names: %s" , cols_names )
146+
147+ return cols_names
148+
149+
150+ def _iterate_results (
151+ con : Any ,
152+ cursor_args : List [Any ],
141153 chunksize : int ,
142- cols_names : List [str ],
143- index : Optional [Union [str , List [str ]]],
154+ index_col : Optional [Union [str , List [str ]]],
144155 safe : bool ,
145156 dtype : Optional [Dict [str , pa .DataType ]],
146157) -> Iterator [pd .DataFrame ]:
147- while True :
148- records = cursor .fetchmany (chunksize )
149- if not records :
150- break
151- yield _records2df (records = records , cols_names = cols_names , index = index , safe = safe , dtype = dtype )
158+ with con .cursor () as cursor :
159+ cursor .execute (* cursor_args )
160+ cols_names = _get_cols_names (cursor .description )
161+ while True :
162+ records = cursor .fetchmany (chunksize )
163+ if not records :
164+ break
165+ yield _records2df (records = records , cols_names = cols_names , index = index_col , safe = safe , dtype = dtype )
166+
167+
168+ def _fetch_all_results (
169+ con : Any ,
170+ cursor_args : List [Any ],
171+ index_col : Optional [Union [str , List [str ]]] = None ,
172+ dtype : Optional [Dict [str , pa .DataType ]] = None ,
173+ safe : bool = True ,
174+ ) -> pd .DataFrame :
175+ with con .cursor () as cursor :
176+ cursor .execute (* cursor_args )
177+ cols_names = _get_cols_names (cursor .description )
178+ return _records2df (
179+ records = cast (List [Tuple [Any ]], cursor .fetchall ()),
180+ cols_names = cols_names ,
181+ index = index_col ,
182+ dtype = dtype ,
183+ safe = safe ,
184+ )
152185
153186
154187def read_sql_query (
@@ -163,23 +196,23 @@ def read_sql_query(
163196 """Read SQL Query (generic)."""
164197 args = _convert_params (sql , params )
165198 try :
166- with con .cursor () as cursor :
167- cursor .execute (* args )
168- cols_names : List [str ] = [
169- col [0 ].decode ("utf-8" ) if isinstance (col [0 ], bytes ) else col [0 ] for col in cursor .description
170- ]
171- _logger .debug ("cols_names: %s" , cols_names )
172- if chunksize is None :
173- return _records2df (
174- records = cast (List [Tuple [Any ]], cursor .fetchall ()),
175- cols_names = cols_names ,
176- index = index_col ,
177- dtype = dtype ,
178- safe = safe ,
179- )
180- return _iterate_cursor (
181- cursor = cursor , chunksize = chunksize , cols_names = cols_names , index = index_col , dtype = dtype , safe = safe
199+ if chunksize is None :
200+ return _fetch_all_results (
201+ con = con ,
202+ cursor_args = args ,
203+ index_col = index_col ,
204+ dtype = dtype ,
205+ safe = safe ,
182206 )
207+
208+ return _iterate_results (
209+ con = con ,
210+ cursor_args = args ,
211+ chunksize = chunksize ,
212+ index_col = index_col ,
213+ dtype = dtype ,
214+ safe = safe ,
215+ )
183216 except Exception as ex :
184217 con .rollback ()
185218 _logger .error (ex )
0 commit comments