
Commit eb04891

Parallel Pandas.read_csv_list and Pandas.to_redshift with Glue Connection
1 parent c6bfaeb commit eb04891

5 files changed (+231, -69 lines)

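For context, a minimal usage sketch of the two features this commit introduces, assuming the legacy Session-based awswrangler API; the bucket, schema, table, IAM role, and Glue connection names below are placeholders:

import awswrangler

session = awswrangler.Session()

# Parallel read: one child process per listed object, capped at procs_cpu_bound workers.
df = session.pandas.read_csv_list(
    paths=["s3://my-bucket/data/part-0.csv",
           "s3://my-bucket/data/part-1.csv"],
    procs_cpu_bound=4)

# Load into Redshift, resolving the connection from a Glue Connection name
# instead of passing a PEP 249 connection object.
session.pandas.to_redshift(
    dataframe=df,
    path="s3://my-bucket/tmp/",
    connection="my-glue-connection",
    schema="public",
    table="my_table",
    iam_role="arn:aws:iam::123456789012:role/my-redshift-copy-role")

When procs_cpu_bound is omitted, read_csv_list falls back to the Session's own procs_cpu_bound setting, or 1.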

awswrangler/pandas.py

Lines changed: 158 additions & 58 deletions
@@ -18,6 +18,7 @@
 from s3fs import S3FileSystem  # type: ignore
 
 from awswrangler import data_types
+from awswrangler import utils
 from awswrangler.exceptions import (UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object,
                                     LineTerminatorNotFound, EmptyDataframe, InvalidSerDe, InvalidCompression,
                                     InvalidParameters, InvalidEngine)
@@ -122,7 +123,8 @@ def read_csv(
                                           encoding=encoding,
                                           converters=converters)
         else:
-            ret = self._read_csv_once(bucket_name=bucket_name,
+            ret = self._read_csv_once(session_primitives=self._session.primitives,
+                                      bucket_name=bucket_name,
                                       key_path=key_path,
                                       header=header,
                                       names=names,
@@ -193,7 +195,8 @@ def _read_csv_iterator(
         if total_size <= 0:
             raise EmptyS3Object(metadata)
         elif total_size <= max_result_size:
-            yield self._read_csv_once(bucket_name=bucket_name,
+            yield self._read_csv_once(session_primitives=self._session.primitives,
+                                      bucket_name=bucket_name,
                                       key_path=key_path,
                                       header=header,
                                       names=names,
@@ -350,20 +353,21 @@ def _find_terminator(body, sep, quoting, quotechar, lineterminator):
             raise LineTerminatorNotFound()
         return index
 
+    @staticmethod
     def _read_csv_once(
-            self,
-            bucket_name,
-            key_path,
-            header="infer",
+            session_primitives: "SessionPrimitives",
+            bucket_name: str,
+            key_path: str,
+            header: Optional[str] = "infer",
             names=None,
             usecols=None,
             dtype=None,
-            sep=",",
+            sep: str = ",",
             thousands=None,
-            decimal=".",
-            lineterminator="\n",
-            quotechar='"',
-            quoting=0,
+            decimal: str = ".",
+            lineterminator: str = "\n",
+            quotechar: str = '"',
+            quoting: int = 0,
             escapechar=None,
             parse_dates: Union[bool, Dict, List] = False,
             infer_datetime_format=False,
@@ -375,6 +379,7 @@ def _read_csv_once(
         Try to mimic as most as possible pandas.read_csv()
         https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html
 
+        :param session_primitives: SessionPrimitives()
         :param bucket_name: S3 bucket name
        :param key_path: S3 key path (W/o bucket)
         :param header: Same as pandas.read_csv()
@@ -395,7 +400,9 @@ def _read_csv_once(
         :return: Pandas Dataframe
         """
         buff = BytesIO()
-        self._client_s3.download_fileobj(Bucket=bucket_name, Key=key_path, Fileobj=buff)
+        session: Session = session_primitives.session
+        client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
+        client_s3.download_fileobj(Bucket=bucket_name, Key=key_path, Fileobj=buff)
         buff.seek(0),
         dataframe = pd.read_csv(
             buff,
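Building the S3 client from session_primitives inside _read_csv_once, instead of reusing the shared self._client_s3, is presumably what makes the method safe to call from the worker processes introduced below: boto3 sessions and clients are generally not safe to share across forked processes, so each worker rebuilds its own. A generic sketch of that pattern, outside awswrangler (bucket, key, and destination are illustrative):

import multiprocessing as mp
import boto3

def download_in_worker(bucket: str, key: str, dest: str) -> None:
    # Create a fresh session/client inside the child process rather than
    # inheriting one constructed by the parent.
    client_s3 = boto3.session.Session().client("s3")
    client_s3.download_file(bucket, key, dest)

if __name__ == "__main__":
    proc = mp.Process(target=download_in_worker,
                      args=("my-bucket", "data/part-0.csv", "/tmp/part-0.csv"))
    proc.start()
    proc.join()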
@@ -418,6 +425,47 @@ def _read_csv_once(
         buff.close()
         return dataframe
 
+    @staticmethod
+    def _read_csv_once_remote(send_pipe: mp.connection.Connection,
+                              session_primitives: "SessionPrimitives",
+                              bucket_name: str,
+                              key_path: str,
+                              header: str = "infer",
+                              names=None,
+                              usecols=None,
+                              dtype=None,
+                              sep: str = ",",
+                              thousands=None,
+                              decimal: str = ".",
+                              lineterminator: str = "\n",
+                              quotechar: str = '"',
+                              quoting: int = 0,
+                              escapechar=None,
+                              parse_dates: Union[bool, Dict, List] = False,
+                              infer_datetime_format=False,
+                              encoding=None,
+                              converters=None):
+        df: pd.DataFrame = Pandas._read_csv_once(session_primitives=session_primitives,
+                                                 bucket_name=bucket_name,
+                                                 key_path=key_path,
+                                                 header=header,
+                                                 names=names,
+                                                 usecols=usecols,
+                                                 dtype=dtype,
+                                                 sep=sep,
+                                                 thousands=thousands,
+                                                 decimal=decimal,
+                                                 lineterminator=lineterminator,
+                                                 quotechar=quotechar,
+                                                 quoting=quoting,
+                                                 escapechar=escapechar,
+                                                 parse_dates=parse_dates,
+                                                 infer_datetime_format=infer_datetime_format,
+                                                 encoding=encoding,
+                                                 converters=converters)
+        send_pipe.send(df)
+        send_pipe.close()
+
     @staticmethod
     def _list_parser(value: str) -> List[Union[int, float, str, None]]:
         # try resolve with a simple literal_eval
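_read_csv_once_remote is the per-object worker entry point: it parses one CSV and ships the resulting DataFrame back over a multiprocessing Pipe, closing its end when done. The same protocol in isolation (a sketch; the toy worker stands in for the real S3 read):

import multiprocessing as mp
import pandas as pd

def worker(send_pipe):
    df = pd.DataFrame({"col": [1, 2, 3]})  # stand-in for the CSV parsed from S3
    send_pipe.send(df)                      # pickles the DataFrame into the pipe
    send_pipe.close()

if __name__ == "__main__":
    receive_pipe, send_pipe = mp.Pipe()
    proc = mp.Process(target=worker, args=(send_pipe,))
    proc.start()
    result = receive_pipe.recv()  # read the DataFrame before joining the child
    proc.join()
    receive_pipe.close()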
@@ -1164,7 +1212,7 @@ def to_redshift(
 
         :param dataframe: Pandas Dataframe
         :param path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/)
-        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param connection: Glue connection name (str) OR a PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
         :param schema: The Redshift Schema for the table
         :param table: The name of the desired Redshift table
         :param iam_role: AWS IAM role with the related permissions
@@ -1190,40 +1238,57 @@ def to_redshift(
         self._session.s3.delete_objects(path=path)
         num_rows: int = len(dataframe.index)
         logger.debug(f"Number of rows: {num_rows}")
-        if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
-            num_partitions: int = 1
-        else:
-            num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
-            logger.debug(f"Number of slices on Redshift: {num_slices}")
-            num_partitions = num_slices
-        logger.debug(f"Number of partitions calculated: {num_partitions}")
-        objects_paths: List[str] = self.to_parquet(dataframe=dataframe,
-                                                   path=path,
-                                                   preserve_index=preserve_index,
-                                                   mode="append",
-                                                   procs_cpu_bound=num_partitions,
-                                                   cast_columns=cast_columns_parquet)
-        manifest_path: str = f"{path}manifest.json"
-        self._session.redshift.write_load_manifest(manifest_path=manifest_path, objects_paths=objects_paths)
-        self._session.redshift.load_table(
-            dataframe=dataframe,
-            dataframe_type="pandas",
-            manifest_path=manifest_path,
-            schema_name=schema,
-            table_name=table,
-            redshift_conn=connection,
-            preserve_index=preserve_index,
-            num_files=num_partitions,
-            iam_role=iam_role,
-            diststyle=diststyle,
-            distkey=distkey,
-            sortstyle=sortstyle,
-            sortkey=sortkey,
-            primary_keys=primary_keys,
-            mode=mode,
-            cast_columns=cast_columns,
-        )
-        self._session.s3.delete_objects(path=path)
+
+        generated_conn: bool = False
+        if type(connection) == str:
+            logger.debug("Glue connection (str) provided.")
+            connection = self._session.glue.get_connection(name=connection)
+            generated_conn = True
+
+        try:
+
+            if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
+                num_partitions: int = 1
+            else:
+                num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
+                logger.debug(f"Number of slices on Redshift: {num_slices}")
+                num_partitions = num_slices
+            logger.debug(f"Number of partitions calculated: {num_partitions}")
+            objects_paths: List[str] = self.to_parquet(dataframe=dataframe,
+                                                       path=path,
+                                                       preserve_index=preserve_index,
+                                                       mode="append",
+                                                       procs_cpu_bound=num_partitions,
+                                                       cast_columns=cast_columns_parquet)
+            manifest_path: str = f"{path}manifest.json"
+            self._session.redshift.write_load_manifest(manifest_path=manifest_path, objects_paths=objects_paths)
+            self._session.redshift.load_table(
+                dataframe=dataframe,
+                dataframe_type="pandas",
+                manifest_path=manifest_path,
+                schema_name=schema,
+                table_name=table,
+                redshift_conn=connection,
+                preserve_index=preserve_index,
+                num_files=num_partitions,
+                iam_role=iam_role,
+                diststyle=diststyle,
+                distkey=distkey,
+                sortstyle=sortstyle,
+                sortkey=sortkey,
+                primary_keys=primary_keys,
+                mode=mode,
+                cast_columns=cast_columns,
+            )
+            self._session.s3.delete_objects(path=path)
+
+        except Exception as ex:
+            connection.rollback()
+            if generated_conn is True:
+                connection.close()
+            raise ex
+        if generated_conn is True:
+            connection.close()
 
     def read_log_query(self,
                        query,
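Note the connection lifecycle implied by the new branch: a connection resolved from a Glue connection name is rolled back on failure and closed by to_redshift itself, while a connection object passed in by the caller is rolled back but left open. A caller-managed sketch, reusing the session and df from the sketch above (names and ARN remain placeholders):

conn = session.glue.get_connection(name="my-glue-connection")  # same helper to_redshift uses internally
try:
    session.pandas.to_redshift(dataframe=df,
                               path="s3://my-bucket/tmp/",
                               connection=conn,
                               schema="public",
                               table="my_table",
                               iam_role="arn:aws:iam::123456789012:role/my-redshift-copy-role")
finally:
    conn.close()  # the caller owns this connection, so the caller closes it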
@@ -1346,7 +1411,7 @@ def read_parquet(self,
 
     @staticmethod
     def _read_parquet_paths_remote(send_pipe: mp.connection.Connection,
-                                   session_primitives: Any,
+                                   session_primitives: "SessionPrimitives",
                                    path: Union[str, List[str]],
                                    columns: Optional[List[str]] = None,
                                    filters: Optional[Union[List[Tuple[Any]], List[List[Tuple[Any]]]]] = None,
@@ -1364,7 +1429,7 @@ def _read_parquet_paths_remote(send_pipe: mp.connection.Connection,
         send_pipe.close()
 
     @staticmethod
-    def _read_parquet_paths(session_primitives: Any,
+    def _read_parquet_paths(session_primitives: "SessionPrimitives",
                             path: Union[str, List[str]],
                             columns: Optional[List[str]] = None,
                             filters: Optional[Union[List[Tuple[Any]], List[List[Tuple[Any]]]]] = None,
@@ -1694,6 +1759,7 @@ def read_csv_list(
             infer_datetime_format=False,
             encoding="utf-8",
             converters=None,
+            procs_cpu_bound: Optional[int] = None,
     ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
         """
         Read CSV files from AWS S3 using optimized strategies.
@@ -1718,6 +1784,7 @@ def read_csv_list(
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
         :param converters: Same as pandas.read_csv()
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
         """
         if max_result_size is not None:
@@ -1739,11 +1806,16 @@ def read_csv_list(
                                               encoding=encoding,
                                               converters=converters)
         else:
-            df_full: Optional[pd.DataFrame] = None
-            for path in paths:
-                logger.debug(f"path: {path}")
+            procs_cpu_bound = procs_cpu_bound if procs_cpu_bound is not None else self._session.procs_cpu_bound if self._session.procs_cpu_bound is not None else 1
+            logger.debug(f"procs_cpu_bound: {procs_cpu_bound}")
+            df: Optional[pd.DataFrame] = None
+            session_primitives = self._session.primitives
+            if len(paths) == 1:
+                path = paths[0]
                 bucket_name, key_path = Pandas._parse_path(path)
-                df = self._read_csv_once(bucket_name=bucket_name,
+                logger.debug(f"path: {path}")
+                df = self._read_csv_once(session_primitives=self._session.primitives,
+                                         bucket_name=bucket_name,
                                          key_path=key_path,
                                          header=header,
                                          names=names,
@@ -1760,11 +1832,37 @@ def read_csv_list(
                                          infer_datetime_format=infer_datetime_format,
                                          encoding=encoding,
                                          converters=converters)
-                if df_full is None:
-                    df_full = df
-                else:
-                    df_full = pd.concat(objs=[df_full, df], ignore_index=True)
-            return df_full
+            else:
+                procs = []
+                receive_pipes = []
+                logger.debug(f"len(paths): {len(paths)}")
+                for path in paths:
+                    receive_pipe, send_pipe = mp.Pipe()
+                    bucket_name, key_path = Pandas._parse_path(path)
+                    logger.debug(f"launching path: {path}")
+                    proc = mp.Process(
+                        target=self._read_csv_once_remote,
+                        args=(send_pipe, session_primitives, bucket_name, key_path, header, names, usecols, dtype, sep,
+                              thousands, decimal, lineterminator, quotechar, quoting, escapechar, parse_dates,
+                              infer_datetime_format, encoding, converters),
+                    )
+                    proc.daemon = False
+                    proc.start()
+                    procs.append(proc)
+                    receive_pipes.append(receive_pipe)
+                    utils.wait_process_release(processes=procs, target_number=procs_cpu_bound)
+                for i in range(len(procs)):
+                    logger.debug(f"Waiting pipe number: {i}")
+                    df_received = receive_pipes[i].recv()
+                    if df is None:
+                        df = df_received
+                    else:
+                        df = pd.concat(objs=[df, df_received], ignore_index=True)
+                    logger.debug(f"Waiting proc number: {i}")
+                    procs[i].join()
+                    logger.debug(f"Closing proc number: {i}")
+                    receive_pipes[i].close()
+            return df
 
     def _read_csv_list_iterator(
             self,
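The parallel branch is a fan-out/fan-in: one process per S3 object, throttled by utils.wait_process_release so that at most procs_cpu_bound workers run at once, and results are received from the pipes in launch order before each join. Receiving before joining matters, since for large payloads a child can block in send() until the parent drains the pipe. A stripped-down sketch of the same shape, with a toy worker in place of the CSV download:

import multiprocessing as mp
import pandas as pd

def worker(send_pipe, i):
    send_pipe.send(pd.DataFrame({"part": [i]}))  # stand-in for one parsed CSV
    send_pipe.close()

if __name__ == "__main__":
    procs, receive_pipes = [], []
    for i in range(4):  # fan-out: one process per "object"
        receive_pipe, send_pipe = mp.Pipe()
        proc = mp.Process(target=worker, args=(send_pipe, i))
        proc.start()
        procs.append(proc)
        receive_pipes.append(receive_pipe)
    df = None  # fan-in: receive in launch order, then join and close
    for proc, pipe in zip(procs, receive_pipes):
        received = pipe.recv()
        df = received if df is None else pd.concat([df, received], ignore_index=True)
        proc.join()
        pipe.close()
    print(df)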
@@ -1852,6 +1950,7 @@ def read_csv_prefix(
             infer_datetime_format=False,
             encoding="utf-8",
             converters=None,
+            procs_cpu_bound: Optional[int] = None,
     ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
         """
         Read CSV files from AWS S3 PREFIX using optimized strategies.
@@ -1876,6 +1975,7 @@ def read_csv_prefix(
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
         :param converters: Same as pandas.read_csv()
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
         """
         paths: List[str] = self._session.s3.list_objects(path=path_prefix)

awswrangler/utils.py

Lines changed: 17 additions & 8 deletions
@@ -29,22 +29,31 @@ def calculate_bounders(num_items, num_groups=None, max_size=None):
         raise InvalidArguments("You must give num_groups or max_size!")
 
 
-def wait_process_release(processes):
+def wait_process_release(processes, target_number=None):
     """
     Wait one of the processes releases
     :param processes: List of processes
+    :param target_number: Wait for a target number of running processes
     :return: None
     """
     n = len(processes)
     i = 0
     while True:
-        if not processes[i].is_alive():
-            del processes[i]
-            return None
-        i += 1
-        if i == n:
-            i = 0
-            sleep(0.1)
+        if target_number is None:
+            if processes[i].is_alive() is False:
+                del processes[i]
+                return None
+            i += 1
+            if i == n:
+                i = 0
+        else:
+            count = 0
+            for p in processes:
+                if p.is_alive():
+                    count += 1
+            if count <= target_number:
+                return count
+        sleep(0.25)
 
 
 def lcm(a: int, b: int) -> int:
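With target_number set, wait_process_release no longer waits for a single slot to free up; it polls every 0.25 s until at most target_number of the given processes are still alive (returning that count), which is how read_csv_list caps concurrency at procs_cpu_bound. A minimal sketch of the intended call pattern (the worker body is a placeholder):

import multiprocessing as mp
import time

from awswrangler import utils

def worker(seconds: int) -> None:
    time.sleep(seconds)  # placeholder for downloading and parsing one object

if __name__ == "__main__":
    procs = []
    for seconds in range(8):
        proc = mp.Process(target=worker, args=(seconds,))
        proc.start()
        procs.append(proc)
        # Block until no more than 4 of the launched processes are still running.
        utils.wait_process_release(processes=procs, target_number=4)
    for proc in procs:
        proc.join()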

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ mypy~=0.761
 flake8~=3.7.9
 pytest-cov~=2.8.1
 scikit-learn~=0.22.1
-cfn-lint~=0.26.2
+cfn-lint~=0.26.3
 twine~=3.1.1
 wheel~=0.33.6
 sphinx~=2.3.1
