 from s3fs import S3FileSystem  # type: ignore
 
 from awswrangler import data_types
+from awswrangler import utils
 from awswrangler.exceptions import (UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object,
                                     LineTerminatorNotFound, EmptyDataframe, InvalidSerDe, InvalidCompression,
                                     InvalidParameters, InvalidEngine)
@@ -122,7 +123,8 @@ def read_csv(
                                           encoding=encoding,
                                           converters=converters)
         else:
-            ret = self._read_csv_once(bucket_name=bucket_name,
+            ret = self._read_csv_once(session_primitives=self._session.primitives,
+                                      bucket_name=bucket_name,
                                       key_path=key_path,
                                       header=header,
                                       names=names,
@@ -193,7 +195,8 @@ def _read_csv_iterator(
         if total_size <= 0:
             raise EmptyS3Object(metadata)
         elif total_size <= max_result_size:
-            yield self._read_csv_once(bucket_name=bucket_name,
+            yield self._read_csv_once(session_primitives=self._session.primitives,
+                                      bucket_name=bucket_name,
                                       key_path=key_path,
                                       header=header,
                                       names=names,
@@ -350,20 +353,21 @@ def _find_terminator(body, sep, quoting, quotechar, lineterminator):
             raise LineTerminatorNotFound()
         return index
 
+    @staticmethod
     def _read_csv_once(
-            self,
-            bucket_name,
-            key_path,
-            header="infer",
+            session_primitives: "SessionPrimitives",
+            bucket_name: str,
+            key_path: str,
+            header: Optional[str] = "infer",
             names=None,
             usecols=None,
             dtype=None,
-            sep=",",
+            sep: str = ",",
             thousands=None,
-            decimal=".",
-            lineterminator="\n",
-            quotechar='"',
-            quoting=0,
+            decimal: str = ".",
+            lineterminator: str = "\n",
+            quotechar: str = '"',
+            quoting: int = 0,
             escapechar=None,
             parse_dates: Union[bool, Dict, List] = False,
             infer_datetime_format=False,
@@ -375,6 +379,7 @@ def _read_csv_once(
         Try to mimic as most as possible pandas.read_csv()
         https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html
 
+        :param session_primitives: SessionPrimitives()
         :param bucket_name: S3 bucket name
         :param key_path: S3 key path (W/o bucket)
         :param header: Same as pandas.read_csv()
@@ -395,7 +400,9 @@ def _read_csv_once(
         :return: Pandas Dataframe
         """
         buff = BytesIO()
-        self._client_s3.download_fileobj(Bucket=bucket_name, Key=key_path, Fileobj=buff)
+        session: Session = session_primitives.session
+        client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
+        client_s3.download_fileobj(Bucket=bucket_name, Key=key_path, Fileobj=buff)
         buff.seek(0),
         dataframe = pd.read_csv(
             buff,
@@ -418,6 +425,47 @@ def _read_csv_once(
         buff.close()
         return dataframe
 
+    @staticmethod
+    def _read_csv_once_remote(send_pipe: mp.connection.Connection,
+                              session_primitives: "SessionPrimitives",
+                              bucket_name: str,
+                              key_path: str,
+                              header: str = "infer",
+                              names=None,
+                              usecols=None,
+                              dtype=None,
+                              sep: str = ",",
+                              thousands=None,
+                              decimal: str = ".",
+                              lineterminator: str = "\n",
+                              quotechar: str = '"',
+                              quoting: int = 0,
+                              escapechar=None,
+                              parse_dates: Union[bool, Dict, List] = False,
+                              infer_datetime_format=False,
+                              encoding=None,
+                              converters=None):
+        df: pd.DataFrame = Pandas._read_csv_once(session_primitives=session_primitives,
+                                                 bucket_name=bucket_name,
+                                                 key_path=key_path,
+                                                 header=header,
+                                                 names=names,
+                                                 usecols=usecols,
+                                                 dtype=dtype,
+                                                 sep=sep,
+                                                 thousands=thousands,
+                                                 decimal=decimal,
+                                                 lineterminator=lineterminator,
+                                                 quotechar=quotechar,
+                                                 quoting=quoting,
+                                                 escapechar=escapechar,
+                                                 parse_dates=parse_dates,
+                                                 infer_datetime_format=infer_datetime_format,
+                                                 encoding=encoding,
+                                                 converters=converters)
+        send_pipe.send(df)
+        send_pipe.close()
+
     @staticmethod
     def _list_parser(value: str) -> List[Union[int, float, str, None]]:
         # try resolve with a simple literal_eval
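
Note on the refactor above: `_read_csv_once` becomes a `@staticmethod` that receives `SessionPrimitives` instead of reading `self._client_s3`, presumably because boto3 clients cannot be pickled and handed to child processes, so each worker rebuilds its own S3 client and ships the resulting DataFrame back through a pipe. A minimal standalone sketch of that pattern, with a hypothetical `fetch_csv` worker and placeholder bucket/key that are not part of this diff:

import io
import multiprocessing as mp

import boto3
import pandas as pd


def fetch_csv(send_pipe, bucket, key):
    # each child process creates its own S3 client, downloads one object into
    # memory, parses it and sends the DataFrame back through its end of the pipe
    buff = io.BytesIO()
    boto3.client("s3").download_fileobj(Bucket=bucket, Key=key, Fileobj=buff)
    buff.seek(0)
    send_pipe.send(pd.read_csv(buff))
    send_pipe.close()


if __name__ == "__main__":
    receive_pipe, send_pipe = mp.Pipe()
    proc = mp.Process(target=fetch_csv, args=(send_pipe, "my-bucket", "data/part-0.csv"))
    proc.start()
    df = receive_pipe.recv()
    proc.join()
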
@@ -1164,7 +1212,7 @@ def to_redshift(
 
         :param dataframe: Pandas Dataframe
         :param path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/)
-        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param connection: Glue connection name (str) OR a PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
         :param schema: The Redshift Schema for the table
         :param table: The name of the desired Redshift table
         :param iam_role: AWS IAM role with the related permissions
@@ -1190,40 +1238,57 @@ def to_redshift(
         self._session.s3.delete_objects(path=path)
         num_rows: int = len(dataframe.index)
         logger.debug(f"Number of rows: {num_rows}")
-        if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
-            num_partitions: int = 1
-        else:
-            num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
-            logger.debug(f"Number of slices on Redshift: {num_slices}")
-            num_partitions = num_slices
-        logger.debug(f"Number of partitions calculated: {num_partitions}")
-        objects_paths: List[str] = self.to_parquet(dataframe=dataframe,
-                                                   path=path,
-                                                   preserve_index=preserve_index,
-                                                   mode="append",
-                                                   procs_cpu_bound=num_partitions,
-                                                   cast_columns=cast_columns_parquet)
-        manifest_path: str = f"{path}manifest.json"
-        self._session.redshift.write_load_manifest(manifest_path=manifest_path, objects_paths=objects_paths)
-        self._session.redshift.load_table(
-            dataframe=dataframe,
-            dataframe_type="pandas",
-            manifest_path=manifest_path,
-            schema_name=schema,
-            table_name=table,
-            redshift_conn=connection,
-            preserve_index=preserve_index,
-            num_files=num_partitions,
-            iam_role=iam_role,
-            diststyle=diststyle,
-            distkey=distkey,
-            sortstyle=sortstyle,
-            sortkey=sortkey,
-            primary_keys=primary_keys,
-            mode=mode,
-            cast_columns=cast_columns,
-        )
-        self._session.s3.delete_objects(path=path)
+
+        generated_conn: bool = False
+        if type(connection) == str:
+            logger.debug("Glue connection (str) provided.")
+            connection = self._session.glue.get_connection(name=connection)
+            generated_conn = True
+
+        try:
+
+            if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
+                num_partitions: int = 1
+            else:
+                num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
+                logger.debug(f"Number of slices on Redshift: {num_slices}")
+                num_partitions = num_slices
+            logger.debug(f"Number of partitions calculated: {num_partitions}")
+            objects_paths: List[str] = self.to_parquet(dataframe=dataframe,
+                                                       path=path,
+                                                       preserve_index=preserve_index,
+                                                       mode="append",
+                                                       procs_cpu_bound=num_partitions,
+                                                       cast_columns=cast_columns_parquet)
+            manifest_path: str = f"{path}manifest.json"
+            self._session.redshift.write_load_manifest(manifest_path=manifest_path, objects_paths=objects_paths)
+            self._session.redshift.load_table(
+                dataframe=dataframe,
+                dataframe_type="pandas",
+                manifest_path=manifest_path,
+                schema_name=schema,
+                table_name=table,
+                redshift_conn=connection,
+                preserve_index=preserve_index,
+                num_files=num_partitions,
+                iam_role=iam_role,
+                diststyle=diststyle,
+                distkey=distkey,
+                sortstyle=sortstyle,
+                sortkey=sortkey,
+                primary_keys=primary_keys,
+                mode=mode,
+                cast_columns=cast_columns,
+            )
+            self._session.s3.delete_objects(path=path)
+
+        except Exception as ex:
+            connection.rollback()
+            if generated_conn is True:
+                connection.close()
+            raise ex
+        if generated_conn is True:
+            connection.close()
 
     def read_log_query(self,
                        query,
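
With this change `connection` accepts either a Glue connection name or an existing PEP 249 connection; when a name is passed, the connection is resolved from the Glue catalog, rolled back on failure, and closed again inside the call. A hedged usage sketch against the 0.x `Session` API used elsewhere in this module (bucket, IAM role and connection name are placeholders, not values from this diff):

import awswrangler
import pandas as pd

df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "bar", "baz"]})

session = awswrangler.Session()
# passing the Glue connection name as a string; to_redshift now resolves it,
# loads the table, and closes the connection (rolling back if anything fails)
session.pandas.to_redshift(dataframe=df,
                           path="s3://my-bucket/tmp/",
                           connection="my-glue-redshift-connection",
                           schema="public",
                           table="my_table",
                           iam_role="arn:aws:iam::123456789012:role/my-redshift-role",
                           mode="overwrite")
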
@@ -1346,7 +1411,7 @@ def read_parquet(self,
 
     @staticmethod
     def _read_parquet_paths_remote(send_pipe: mp.connection.Connection,
-                                   session_primitives: Any,
+                                   session_primitives: "SessionPrimitives",
                                    path: Union[str, List[str]],
                                    columns: Optional[List[str]] = None,
                                    filters: Optional[Union[List[Tuple[Any]], List[List[Tuple[Any]]]]] = None,
@@ -1364,7 +1429,7 @@ def _read_parquet_paths_remote(send_pipe: mp.connection.Connection,
         send_pipe.close()
 
     @staticmethod
-    def _read_parquet_paths(session_primitives: Any,
+    def _read_parquet_paths(session_primitives: "SessionPrimitives",
                             path: Union[str, List[str]],
                             columns: Optional[List[str]] = None,
                             filters: Optional[Union[List[Tuple[Any]], List[List[Tuple[Any]]]]] = None,
@@ -1694,6 +1759,7 @@ def read_csv_list(
             infer_datetime_format=False,
             encoding="utf-8",
             converters=None,
+            procs_cpu_bound: Optional[int] = None,
     ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
         """
         Read CSV files from AWS S3 using optimized strategies.
@@ -1718,6 +1784,7 @@ def read_csv_list(
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
         :param converters: Same as pandas.read_csv()
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
         """
         if max_result_size is not None:
@@ -1739,11 +1806,16 @@ def read_csv_list(
                                                 encoding=encoding,
                                                 converters=converters)
         else:
-            df_full: Optional[pd.DataFrame] = None
-            for path in paths:
-                logger.debug(f"path: {path}")
+            procs_cpu_bound = procs_cpu_bound if procs_cpu_bound is not None else self._session.procs_cpu_bound if self._session.procs_cpu_bound is not None else 1
+            logger.debug(f"procs_cpu_bound: {procs_cpu_bound}")
+            df: Optional[pd.DataFrame] = None
+            session_primitives = self._session.primitives
+            if len(paths) == 1:
+                path = paths[0]
                 bucket_name, key_path = Pandas._parse_path(path)
-                df = self._read_csv_once(bucket_name=bucket_name,
+                logger.debug(f"path: {path}")
+                df = self._read_csv_once(session_primitives=self._session.primitives,
+                                         bucket_name=bucket_name,
                                          key_path=key_path,
                                          header=header,
                                          names=names,
@@ -1760,11 +1832,37 @@ def read_csv_list(
                                          infer_datetime_format=infer_datetime_format,
                                          encoding=encoding,
                                          converters=converters)
-                if df_full is None:
-                    df_full = df
-                else:
-                    df_full = pd.concat(objs=[df_full, df], ignore_index=True)
-            return df_full
+            else:
+                procs = []
+                receive_pipes = []
+                logger.debug(f"len(paths): {len(paths)}")
+                for path in paths:
+                    receive_pipe, send_pipe = mp.Pipe()
+                    bucket_name, key_path = Pandas._parse_path(path)
+                    logger.debug(f"launching path: {path}")
+                    proc = mp.Process(
+                        target=self._read_csv_once_remote,
+                        args=(send_pipe, session_primitives, bucket_name, key_path, header, names, usecols, dtype, sep,
+                              thousands, decimal, lineterminator, quotechar, quoting, escapechar, parse_dates,
+                              infer_datetime_format, encoding, converters),
+                    )
+                    proc.daemon = False
+                    proc.start()
+                    procs.append(proc)
+                    receive_pipes.append(receive_pipe)
+                    utils.wait_process_release(processes=procs, target_number=procs_cpu_bound)
+                for i in range(len(procs)):
+                    logger.debug(f"Waiting pipe number: {i}")
+                    df_received = receive_pipes[i].recv()
+                    if df is None:
+                        df = df_received
+                    else:
+                        df = pd.concat(objs=[df, df_received], ignore_index=True)
+                    logger.debug(f"Waiting proc number: {i}")
+                    procs[i].join()
+                    logger.debug(f"Closing proc number: {i}")
+                    receive_pipes[i].close()
+            return df
 
     def _read_csv_list_iterator(
             self,
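
In the rewrite above, a single path is still read in-process, while multiple paths fan out one child process per object, throttled to `procs_cpu_bound` concurrent workers by `utils.wait_process_release`, and the parent concatenates the DataFrames it receives from the pipes. A hedged usage sketch (paths and core count are placeholders; parameter names other than `procs_cpu_bound` are assumed from the surrounding code, not shown in this diff):

import awswrangler

session = awswrangler.Session()
# several CSV objects read in parallel and returned as one concatenated DataFrame;
# procs_cpu_bound caps how many child processes run at the same time
df = session.pandas.read_csv_list(paths=["s3://my-bucket/data/part-0.csv",
                                         "s3://my-bucket/data/part-1.csv",
                                         "s3://my-bucket/data/part-2.csv"],
                                  procs_cpu_bound=4)
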
@@ -1852,6 +1950,7 @@ def read_csv_prefix(
             infer_datetime_format=False,
             encoding="utf-8",
             converters=None,
+            procs_cpu_bound: Optional[int] = None,
     ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
         """
         Read CSV files from AWS S3 PREFIX using optimized strategies.
@@ -1876,6 +1975,7 @@ def read_csv_prefix(
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
         :param converters: Same as pandas.read_csv()
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
         """
         paths: List[str] = self._session.s3.list_objects(path=path_prefix)
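
`read_csv_prefix` gains the same knob: it lists the objects under the prefix and then delegates to the same single-path/parallel logic as `read_csv_list`. A short hedged example (the prefix is a placeholder):

import awswrangler

session = awswrangler.Session()
# same parallel behaviour, but the object list comes from an S3 prefix listing
df = session.pandas.read_csv_prefix(path_prefix="s3://my-bucket/data/",
                                    procs_cpu_bound=4)
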