@@ -155,29 +155,15 @@ def read_sql_query(
     ... )
 
     """
-    if not isinstance(con, sqlalchemy.engine.Engine):  # pragma: no cover
-        raise exceptions.InvalidConnection(
-            "Invalid 'con' argument, please pass a "
-            "SQLAlchemy Engine. Use wr.db.get_engine(), "
-            "wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()"
-        )
+    _validate_engine(con=con)
     with con.connect() as _con:
         args = _convert_params(sql, params)
         cursor = _con.execute(*args)
         if chunksize is None:
             return _records2df(records=cursor.fetchall(), cols_names=cursor.keys(), index=index_col, dtype=dtype)
-        return _iterate_cursor(cursor=cursor, chunksize=chunksize, index=index_col, dtype=dtype)
-
-
-def _iterate_cursor(
-    cursor, chunksize: int, index: Optional[Union[str, List[str]]], dtype: Optional[Dict[str, pa.DataType]] = None
-) -> Iterator[pd.DataFrame]:
-    while True:
-        records = cursor.fetchmany(chunksize)
-        if not records:
-            break
-        df: pd.DataFrame = _records2df(records=records, cols_names=cursor.keys(), index=index, dtype=dtype)
-        yield df
+        return _iterate_cursor(
+            cursor=cursor, chunksize=chunksize, cols_names=cursor.keys(), index=index_col, dtype=dtype
+        )
 
 
 def _records2df(
@@ -207,6 +193,20 @@ def _records2df(
     return df
 
 
+def _iterate_cursor(
+    cursor: Any,
+    chunksize: int,
+    cols_names: List[str],
+    index: Optional[Union[str, List[str]]],
+    dtype: Optional[Dict[str, pa.DataType]] = None,
+) -> Iterator[pd.DataFrame]:
+    while True:
+        records = cursor.fetchmany(chunksize)
+        if not records:
+            break
+        yield _records2df(records=records, cols_names=cols_names, index=index, dtype=dtype)
+
+
 def _convert_params(sql: str, params: Optional[Union[List, Tuple, Dict]]) -> List[Any]:
     args: List[Any] = [sql]
     if params is not None:
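The relocated _iterate_cursor is a plain generator: it drains the DBAPI cursor with fetchmany, so only one chunk of rows is materialized at a time, and it now receives cols_names as a parameter instead of calling cursor.keys() for every chunk. A minimal self-contained sketch of the same pattern, using sqlite3 and pandas directly (the helper name and table are illustrative, not part of this commit):

    import sqlite3
    from typing import Any, Iterator, List

    import pandas as pd

    def iterate_cursor(cursor: Any, chunksize: int, cols_names: List[str]) -> Iterator[pd.DataFrame]:
        # fetchmany returns an empty list once the cursor is exhausted.
        while True:
            records = cursor.fetchmany(chunksize)
            if not records:
                break
            yield pd.DataFrame.from_records(records, columns=cols_names)

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE t (a INTEGER)")
    con.executemany("INSERT INTO t VALUES (?)", [(i,) for i in range(10)])
    for df in iterate_cursor(con.execute("SELECT a FROM t"), chunksize=4, cols_names=["a"]):
        print(len(df))  # 4, 4, 2

Taking the column names as an argument also decouples the generator from SQLAlchemy's result-proxy API, which makes it easier to test in isolation.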
@@ -646,7 +646,7 @@ def copy_files_to_redshift( # pylint: disable=too-many-locals,too-many-argument
     athena_types, _ = s3.read_parquet_metadata(
         path=paths, dataset=False, use_threads=use_threads, boto3_session=session
     )
-    _logger.debug(f"athena_types: {athena_types}")
+    _logger.debug("athena_types: %s", athena_types)
     redshift_types: Dict[str, str] = {}
     for col_name, col_type in athena_types.items():
         length: int = _varchar_lengths[col_name] if col_name in _varchar_lengths else varchar_lengths_default
@@ -680,7 +680,7 @@ def copy_files_to_redshift( # pylint: disable=too-many-locals,too-many-argument
 def _rs_upsert(con: Any, table: str, temp_table: str, schema: str, primary_keys: Optional[List[str]] = None) -> None:
     if not primary_keys:
         primary_keys = _rs_get_primary_keys(con=con, schema=schema, table=table)
-    _logger.debug(f"primary_keys: {primary_keys}")
+    _logger.debug("primary_keys: %s", primary_keys)
     if not primary_keys:  # pragma: no cover
         raise exceptions.InvalidRedshiftPrimaryKeys()
     equals_clause: str = f"{table}.%s = {temp_table}.%s"
@@ -735,7 +735,7 @@ def _rs_create_table(
         f"{distkey_str}"
         f"{sortkey_str}"
     )
-    _logger.debug(f"Create table query:\n{sql}")
+    _logger.debug("Create table query:\n%s", sql)
     con.execute(sql)
     return table, schema
 
@@ -746,7 +746,7 @@ def _rs_validate_parameters(
     if diststyle not in _RS_DISTSTYLES:
         raise exceptions.InvalidRedshiftDiststyle(f"diststyle must be in {_RS_DISTSTYLES}")
     cols = list(redshift_types.keys())
-    _logger.debug(f"Redshift columns: {cols}")
+    _logger.debug("Redshift columns: %s", cols)
     if (diststyle == "KEY") and (not distkey):
         raise exceptions.InvalidRedshiftDistkey("You must pass a distkey if you intend to use KEY diststyle")
     if distkey and distkey not in cols:
@@ -775,13 +775,13 @@ def _rs_copy(
     sql: str = (
         f"COPY {table_name} FROM '{manifest_path}'\n" f"IAM_ROLE '{iam_role}'\n" "MANIFEST\n" "FORMAT AS PARQUET"
     )
-    _logger.debug(f"copy query:\n{sql}")
+    _logger.debug("copy query:\n%s", sql)
     con.execute(sql)
     sql = "SELECT pg_last_copy_id() AS query_id"
     query_id: int = con.execute(sql).fetchall()[0][0]
     sql = f"SELECT COUNT(DISTINCT filename) as num_files_loaded " f"FROM STL_LOAD_COMMITS WHERE query = {query_id}"
     num_files_loaded: int = con.execute(sql).fetchall()[0][0]
-    _logger.debug(f"{num_files_loaded} files counted. {num_files} expected.")
+    _logger.debug("%s files counted. %s expected.", num_files_loaded, num_files)
     if num_files_loaded != num_files:  # pragma: no cover
         raise exceptions.RedshiftLoadError(
             f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected."
@@ -846,17 +846,17 @@ def write_redshift_copy_manifest(
     payload: str = json.dumps(manifest)
     bucket: str
     bucket, key = _utils.parse_path(manifest_path)
-    _logger.debug(f"payload: {payload}")
+    _logger.debug("payload: %s", payload)
     client_s3: boto3.client = _utils.client(service_name="s3", session=session)
-    _logger.debug(f"bucket: {bucket}")
-    _logger.debug(f"key: {key}")
+    _logger.debug("bucket: %s", bucket)
+    _logger.debug("key: %s", key)
     client_s3.put_object(Body=payload, Bucket=bucket, Key=key)
     return manifest
 
 
 def _rs_drop_table(con: Any, schema: str, table: str) -> None:
     sql = f"DROP TABLE IF EXISTS {schema}.{table}"
-    _logger.debug(f"Drop table query:\n{sql}")
+    _logger.debug("Drop table query:\n%s", sql)
     con.execute(sql)
 
 
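The f-string to %s-style conversions in the hunks above make the debug calls lazy: the logging module only interpolates the arguments if a handler actually accepts DEBUG records, whereas an f-string is formatted before logger.debug is even called. A small standalone sketch of the difference (the logger name and class are illustrative):

    import logging

    logging.basicConfig(level=logging.INFO)  # DEBUG records are filtered out
    logger = logging.getLogger("demo")

    class Expensive:
        def __str__(self) -> str:
            print("formatting!")  # visible side effect when formatting happens
            return "expensive"

    value = Expensive()
    logger.debug(f"value: {value}")   # eager: prints "formatting!" even though nothing is logged
    logger.debug("value: %s", value)  # lazy: __str__ never runs at INFO level

This is the same change that pylint's logging-fstring-interpolation check (W1203) asks for.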
@@ -1104,5 +1104,14 @@ def unload_redshift_to_files(
         query_id: int = _con.execute(sql).fetchall()[0][0]
         sql = f"SELECT path FROM STL_UNLOAD_LOG WHERE query={query_id};"
         paths = [x[0].replace(" ", "") for x in _con.execute(sql).fetchall()]
-        _logger.debug(f"paths: {paths}")
+        _logger.debug("paths: %s", paths)
         return paths
+
+
+def _validate_engine(con: sqlalchemy.engine.Engine) -> None:  # pragma: no cover
+    if not isinstance(con, sqlalchemy.engine.Engine):
+        raise exceptions.InvalidConnection(
+            "Invalid 'con' argument, please pass a "
+            "SQLAlchemy Engine. Use wr.db.get_engine(), "
+            "wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()"
+        )
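The extracted _validate_engine keeps the old error message and behavior; it simply runs as a reusable guard before con.connect() is attempted. A hedged usage sketch of the same check, standalone, with ValueError standing in for the package's InvalidConnection:

    import sqlalchemy

    def validate_engine(con) -> None:
        # Same isinstance guard the commit factors out of read_sql_query.
        if not isinstance(con, sqlalchemy.engine.Engine):
            raise ValueError("Invalid 'con' argument, please pass a SQLAlchemy Engine")

    engine = sqlalchemy.create_engine("sqlite:///:memory:")
    validate_engine(engine)  # OK: a real Engine passes silently

    try:
        validate_engine("sqlite:///:memory:")  # a URL string is not an Engine
    except ValueError as err:
        print(err)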