Commit 7f0b4b6

Reducing I/O parallelism for some specific operations
1 parent 96075b9 commit 7f0b4b6

9 files changed: 146 additions and 75 deletions

README.md

Lines changed: 1 addition & 1 deletion

@@ -208,7 +208,7 @@ session.spark.create_glue_table(dataframe=dataframe,
 ```py3
 session = awswrangler.Session(spark_session=spark)
 dfs = session.spark.flatten(dataframe=df_nested)
-for name, df_flat in dfs:
+for name, df_flat in dfs.items():
     print(name)
     df_flat.show()
 ```

awswrangler/pandas.py

Lines changed: 13 additions & 13 deletions

@@ -885,12 +885,11 @@ def write_csv_dataframe(dataframe, path, preserve_index, compression, fs, extra_
         Pandas._write_csv_to_s3_retrying(fs=fs, path=path, buffer=csv_buffer)
 
     @staticmethod
-    @tenacity.retry(
-        retry=tenacity.retry_if_exception_type(exception_types=(ClientError, HTTPClientError)),
-        wait=tenacity.wait_random_exponential(multiplier=0.5, max=10),
-        stop=tenacity.stop_after_attempt(max_attempt_number=15),
-        reraise=True,
-    )
+    @tenacity.retry(retry=tenacity.retry_if_exception_type(exception_types=(ClientError, HTTPClientError)),
+                    wait=tenacity.wait_random_exponential(multiplier=0.5),
+                    stop=tenacity.stop_after_attempt(max_attempt_number=10),
+                    reraise=True,
+                    after=tenacity.after_log(logger, logging.INFO))
     def _write_csv_to_s3_retrying(fs: Any, path: str, buffer: bytes) -> None:
         with fs.open(path, "wb") as f:
             f.write(buffer)

@@ -931,12 +930,11 @@ def write_parquet_dataframe(dataframe, path, preserve_index, compression, fs, ca
             dataframe[col] = dataframe[col].astype("Int64")
 
     @staticmethod
-    @tenacity.retry(
-        retry=tenacity.retry_if_exception_type(exception_types=[ClientError, HTTPClientError]),
-        wait=tenacity.wait_random_exponential(multiplier=0.5, max=10),
-        stop=tenacity.stop_after_attempt(max_attempt_number=15),
-        reraise=True,
-    )
+    @tenacity.retry(retry=tenacity.retry_if_exception_type(exception_types=(ClientError, HTTPClientError)),
+                    wait=tenacity.wait_random_exponential(multiplier=0.5),
+                    stop=tenacity.stop_after_attempt(max_attempt_number=10),
+                    reraise=True,
+                    after=tenacity.after_log(logger, logging.INFO))
     def _write_parquet_to_s3_retrying(fs: Any, path: str, table: pa.Table, compression: str) -> None:
         with fs.open(path, "wb") as f:
             pq.write_table(table, f, compression=compression, coerce_timestamps="ms", flavor="spark")

@@ -1066,5 +1064,7 @@ def drop_duplicated_columns(dataframe: pd.DataFrame, inplace: bool = True) -> pd
         if inplace is False:
             dataframe = dataframe.copy(deep=True)
         duplicated_cols = dataframe.columns.duplicated()
-        logger.warning(f"Dropping repeated columns: {list(dataframe.columns[duplicated_cols])}")
+        duplicated_cols_names = list(dataframe.columns[duplicated_cols])
+        if len(duplicated_cols_names) > 0:
+            logger.warning(f"Dropping repeated columns: {duplicated_cols_names}")
         return dataframe.loc[:, ~duplicated_cols]
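
Note on the retry change above: the reshaped decorator drops the max=10 cap on the random-exponential wait (so the backoff may now grow past 10 seconds, up to tenacity's default ceiling), lowers the attempt limit from 15 to 10, and adds an after_log hook so every failed attempt is logged at INFO. Below is a minimal, self-contained sketch of that decorator shape using the public tenacity API; FlakyError, flaky_write, and the attempt counter are illustrative stand-ins and do not appear in awswrangler.

    import logging

    import tenacity

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("retry-sketch")


    class FlakyError(Exception):
        """Stand-in for botocore's ClientError / HTTPClientError."""


    _attempts = {"count": 0}


    @tenacity.retry(retry=tenacity.retry_if_exception_type(exception_types=(FlakyError,)),
                    wait=tenacity.wait_random_exponential(multiplier=0.5),
                    stop=tenacity.stop_after_attempt(max_attempt_number=10),
                    reraise=True,
                    after=tenacity.after_log(logger, logging.INFO))
    def flaky_write() -> str:
        """Fails twice, then succeeds, so the retry/backoff/logging path is exercised."""
        _attempts["count"] += 1
        if _attempts["count"] < 3:
            raise FlakyError("transient failure")
        return "ok"


    print(flaky_write())  # each failed attempt is reported through after_log at INFO level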

awswrangler/redshift.py

Lines changed: 18 additions & 6 deletions

@@ -1,3 +1,4 @@
+from typing import Dict, List, Union, Optional
 import json
 import logging
 

@@ -116,14 +117,25 @@ def get_connection(self, glue_connection):
         conn = self.generate_connection(database=database, host=host, port=int(port), user=user, password=password)
         return conn
 
-    def write_load_manifest(self, manifest_path, objects_paths):
-        objects_sizes = self._session.s3.get_objects_sizes(objects_paths=objects_paths)
-        manifest = {"entries": []}
+    def write_load_manifest(self, manifest_path: str, objects_paths: List[str], procs_io_bound: Optional[int] = None
+                            ) -> Dict[str, List[Dict[str, Union[str, bool, Dict[str, int]]]]]:
+        objects_sizes: Dict[str, int] = self._session.s3.get_objects_sizes(objects_paths=objects_paths,
+                                                                           procs_io_bound=procs_io_bound)
+        manifest: Dict[str, List[Dict[str, Union[str, bool, Dict[str, int]]]]] = {"entries": []}
+        path: str
+        size: int
         for path, size in objects_sizes.items():
-            entry = {"url": path, "mandatory": True, "meta": {"content_length": size}}
-            manifest.get("entries").append(entry)
-        payload = json.dumps(manifest)
+            entry: Dict[str, Union[str, bool, Dict[str, int]]] = {
+                "url": path,
+                "mandatory": True,
+                "meta": {
+                    "content_length": size
+                }
+            }
+            manifest["entries"].append(entry)
+        payload: str = json.dumps(manifest)
         client_s3 = self._session.boto3_session.client(service_name="s3", config=self._session.botocore_config)
+        bucket: str
         bucket, path = manifest_path.replace("s3://", "").split("/", 1)
         client_s3.put_object(Body=payload, Bucket=bucket, Key=path)
         return manifest
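
write_load_manifest now threads procs_io_bound down to get_objects_sizes and builds the Amazon Redshift COPY manifest entry by entry. As a rough illustration of the JSON this method uploads, here is a standalone sketch; the bucket name, object paths, and sizes are made up, and objects_sizes merely stands in for the result of get_objects_sizes().

    import json
    from typing import Dict, List, Union

    # Hypothetical sizes, standing in for what s3.get_objects_sizes() would return.
    objects_sizes: Dict[str, int] = {
        "s3://example-bucket/redshift-load/part-0.parquet": 1048576,
        "s3://example-bucket/redshift-load/part-1.parquet": 2097152,
    }

    manifest: Dict[str, List[Dict[str, Union[str, bool, Dict[str, int]]]]] = {"entries": []}
    for path, size in objects_sizes.items():
        manifest["entries"].append({
            "url": path,                       # object to be loaded by the COPY command
            "mandatory": True,                 # COPY fails if this object is missing
            "meta": {"content_length": size},  # size hint used by Redshift
        })

    print(json.dumps(manifest, indent=4))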

awswrangler/s3.py

Lines changed: 8 additions & 8 deletions

@@ -1,3 +1,4 @@
+from typing import Dict, List, Optional
 import multiprocessing as mp
 from math import ceil
 import logging

@@ -202,12 +203,11 @@ def list_objects(self, path):
         return keys
 
     @staticmethod
-    @tenacity.retry(
-        retry=tenacity.retry_if_exception_type(exception_types=(ClientError, HTTPClientError)),
-        wait=tenacity.wait_random_exponential(multiplier=0.5, max=10),
-        stop=tenacity.stop_after_attempt(max_attempt_number=15),
-        reraise=True,
-    )
+    @tenacity.retry(retry=tenacity.retry_if_exception_type(exception_types=(ClientError, HTTPClientError)),
+                    wait=tenacity.wait_random_exponential(multiplier=0.5),
+                    stop=tenacity.stop_after_attempt(max_attempt_number=10),
+                    reraise=True,
+                    after=tenacity.after_log(logger, logging.INFO))
     def head_object_with_retry(client, bucket, key):
         return client.head_object(Bucket=bucket, Key=key)
 

@@ -226,11 +226,11 @@ def _get_objects_head_remote(send_pipe, session_primitives, objects_paths):
         send_pipe.send(objects_sizes)
         send_pipe.close()
 
-    def get_objects_sizes(self, objects_paths, procs_io_bound=None):
+    def get_objects_sizes(self, objects_paths: List[str], procs_io_bound: Optional[int] = None) -> Dict[str, int]:
         if not procs_io_bound:
             procs_io_bound = self._session.procs_io_bound
         logger.debug(f"procs_io_bound: {procs_io_bound}")
-        objects_sizes = {}
+        objects_sizes: Dict[str, int] = {}
         procs = []
         receive_pipes = []
         bounders = calculate_bounders(len(objects_paths), procs_io_bound)
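
get_objects_sizes now accepts an optional procs_io_bound override, so callers such as the Spark writer can reduce how many processes issue HEAD requests in parallel; internally the object paths are split across that many workers via calculate_bounders. The sketch below only illustrates that splitting idea with a simplified, hypothetical split_paths helper; it is not awswrangler's calculate_bounders implementation.

    from typing import List


    def split_paths(paths: List[str], procs_io_bound: int) -> List[List[str]]:
        """Simplified stand-in for the bounders logic: spread the paths as evenly
        as possible across at most procs_io_bound worker processes."""
        num_groups = max(1, min(procs_io_bound, len(paths)))
        base, rest = divmod(len(paths), num_groups)
        chunks: List[List[str]] = []
        start = 0
        for i in range(num_groups):
            size = base + (1 if i < rest else 0)
            chunks.append(paths[start:start + size])
            start += size
        return chunks


    paths = [f"s3://example-bucket/key-{i}" for i in range(10)]
    print(split_paths(paths, procs_io_bound=4))  # 4 chunks of sizes 3, 3, 2, 2
    print(split_paths(paths, procs_io_bound=1))  # a single chunk: no parallel HEAD requests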

awswrangler/session.py

Lines changed: 0 additions & 2 deletions

@@ -107,9 +107,7 @@ def _load_new_boto3_session(self):
         if self.aws_access_key_id and self.aws_secret_access_key:
             args["aws_access_key_id"] = self.aws_access_key_id
             args["aws_secret_access_key"] = self.aws_secret_access_key
-
         self._boto3_session = boto3.Session(**args)
-
         self._profile_name = self._boto3_session.profile_name
         self._aws_access_key_id = self._boto3_session.get_credentials().access_key
         self._aws_secret_access_key = self._boto3_session.get_credentials().secret_key

awswrangler/spark.py

Lines changed: 52 additions & 38 deletions

@@ -1,12 +1,12 @@
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, Any
 import logging
 import os
 
 import pandas as pd  # type: ignore
 
 from pyspark.sql.functions import pandas_udf, PandasUDFType, spark_partition_id
 from pyspark.sql.types import TimestampType
-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, SparkSession
 
 from awswrangler.exceptions import MissingBatchDetected, UnsupportedFileFormat
 

@@ -18,14 +18,22 @@
 class Spark:
     def __init__(self, session):
         self._session = session
+        cpus: int = os.cpu_count()
+        if cpus == 1:
+            self._procs_io_bound: int = 1
+        else:
+            self._procs_io_bound = int(cpus / 2)
+        logging.info(f"_procs_io_bound: {self._procs_io_bound}")
 
-    def read_csv(self, **args):
-        spark = self._session.spark_session
+    def read_csv(self, **args) -> DataFrame:
+        spark: SparkSession = self._session.spark_session
         return spark.read.csv(**args)
 
     @staticmethod
-    def _extract_casts(dtypes):
-        casts = {}
+    def _extract_casts(dtypes: List[Tuple[str, str]]) -> Dict[str, str]:
+        casts: Dict[str, str] = {}
+        name: str
+        dtype: str
         for name, dtype in dtypes:
             if dtype in ["smallint", "int", "bigint"]:
                 casts[name] = "bigint"

@@ -35,7 +43,9 @@ def _extract_casts(dtypes):
         return casts
 
     @staticmethod
-    def date2timestamp(dataframe):
+    def date2timestamp(dataframe: DataFrame) -> DataFrame:
+        name: str
+        dtype: str
         for name, dtype in dataframe.dtypes:
             if dtype == "date":
                 dataframe = dataframe.withColumn(name, dataframe[name].cast(TimestampType()))

@@ -44,19 +54,19 @@
 
     def to_redshift(
         self,
-        dataframe,
-        path,
-        connection,
-        schema,
-        table,
-        iam_role,
-        diststyle="AUTO",
+        dataframe: DataFrame,
+        path: str,
+        connection: Any,
+        schema: str,
+        table: str,
+        iam_role: str,
+        diststyle: str = "AUTO",
         distkey=None,
-        sortstyle="COMPOUND",
+        sortstyle: str = "COMPOUND",
         sortkey=None,
-        min_num_partitions=200,
-        mode="append",
-    ):
+        min_num_partitions: int = 200,
+        mode: str = "append",
+    ) -> None:
         """
         Load Spark Dataframe as a Table on Amazon Redshift
 

@@ -78,54 +88,58 @@ def to_redshift(
         if path[-1] != "/":
             path += "/"
         self._session.s3.delete_objects(path=path)
-        spark = self._session.spark_session
-        casts = Spark._extract_casts(dataframe.dtypes)
+        spark: SparkSession = self._session.spark_session
+        casts: Dict[str, str] = Spark._extract_casts(dataframe.dtypes)
         dataframe = Spark.date2timestamp(dataframe)
         dataframe.cache()
-        num_rows = dataframe.count()
+        num_rows: int = dataframe.count()
         logger.info(f"Number of rows: {num_rows}")
+        num_partitions: int
         if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
             num_partitions = 1
         else:
-            num_slices = self._session.redshift.get_number_of_slices(redshift_conn=connection)
+            num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
             logger.debug(f"Number of slices on Redshift: {num_slices}")
            num_partitions = num_slices
             while num_partitions < min_num_partitions:
                 num_partitions += num_slices
         logger.debug(f"Number of partitions calculated: {num_partitions}")
         spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         session_primitives = self._session.primitives
+        par_col_name: str = "aws_data_wrangler_internal_partition_id"
 
         @pandas_udf(returnType="objects_paths string", functionType=PandasUDFType.GROUPED_MAP)
-        def write(pandas_dataframe):
+        def write(pandas_dataframe: pd.DataFrame) -> pd.DataFrame:
             # Exporting ARROW_PRE_0_15_IPC_FORMAT environment variable for
             # a temporary workaround while waiting for Apache Arrow updates
             # https://stackoverflow.com/questions/58273063/pandasudf-and-pyarrow-0-15-0
             os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
 
-            del pandas_dataframe["aws_data_wrangler_internal_partition_id"]
-            paths = session_primitives.session.pandas.to_parquet(dataframe=pandas_dataframe,
-                                                                 path=path,
-                                                                 preserve_index=False,
-                                                                 mode="append",
-                                                                 procs_cpu_bound=1,
-                                                                 cast_columns=casts)
+            del pandas_dataframe[par_col_name]
+            paths: List[str] = session_primitives.session.pandas.to_parquet(dataframe=pandas_dataframe,
+                                                                            path=path,
+                                                                            preserve_index=False,
+                                                                            mode="append",
+                                                                            procs_cpu_bound=1,
+                                                                            procs_io_bound=1,
+                                                                            cast_columns=casts)
             return pd.DataFrame.from_dict({"objects_paths": paths})
 
-        df_objects_paths = dataframe.repartition(numPartitions=num_partitions) \
-            .withColumn("aws_data_wrangler_internal_partition_id", spark_partition_id()) \
-            .groupby("aws_data_wrangler_internal_partition_id") \
-            .apply(write)
+        df_objects_paths = dataframe.repartition(numPartitions=num_partitions)  # type: ignore
+        df_objects_paths = df_objects_paths.withColumn(par_col_name, spark_partition_id())  # type: ignore
+        df_objects_paths = df_objects_paths.groupby(par_col_name).apply(write)  # type: ignore
 
-        objects_paths = list(df_objects_paths.toPandas()["objects_paths"])
+        objects_paths: List[str] = list(df_objects_paths.toPandas()["objects_paths"])
         dataframe.unpersist()
-        num_files_returned = len(objects_paths)
+        num_files_returned: int = len(objects_paths)
         if num_files_returned != num_partitions:
             raise MissingBatchDetected(f"{num_files_returned} files returned. {num_partitions} expected.")
         logger.debug(f"List of objects returned: {objects_paths}")
         logger.debug(f"Number of objects returned from UDF: {num_files_returned}")
-        manifest_path = f"{path}manifest.json"
-        self._session.redshift.write_load_manifest(manifest_path=manifest_path, objects_paths=objects_paths)
+        manifest_path: str = f"{path}manifest.json"
+        self._session.redshift.write_load_manifest(manifest_path=manifest_path,
+                                                   objects_paths=objects_paths,
+                                                   procs_io_bound=self._procs_io_bound)
         self._session.redshift.load_table(dataframe=dataframe,
                                           dataframe_type="spark",
                                           manifest_path=manifest_path,
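
This file carries the core of the commit: inside the pandas UDF each Spark task now writes with procs_cpu_bound=1 and procs_io_bound=1 (Spark itself already supplies the parallelism there), and the manifest's HEAD requests are capped at roughly half the driver's cores via the new _procs_io_bound attribute. Below is a small restatement of that heuristic; derive_procs_io_bound is just an illustrative name, not part of the library.

    import os
    from typing import Optional


    def derive_procs_io_bound(cpus: Optional[int] = None) -> int:
        """Restates the Spark.__init__ heuristic: use half of the available cores
        for I/O-bound work on the driver, but never fewer than one process."""
        cpus = os.cpu_count() if cpus is None else cpus
        if not cpus or cpus == 1:
            return 1
        return int(cpus / 2)


    print(derive_procs_io_bound(1))   # 1 -> 1
    print(derive_procs_io_bound(8))   # 8 -> 4
    print(derive_procs_io_bound())    # half of this machine's logical cores

Halving the core count presumably leaves CPU headroom on the driver while object sizes are being gathered, which is consistent with the commit message about reducing I/O parallelism for specific operations.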

docs/source/examples.rst

Lines changed: 1 addition & 1 deletion

@@ -164,7 +164,7 @@ Flatten nested PySpark DataFrame
 
     session = awswrangler.Session(spark_session=spark)
     dfs = session.spark.flatten(dataframe=df_nested)
-    for name, df_flat in dfs:
+    for name, df_flat in dfs.items():
         print(name)
         df_flat.show()
 

testing/test_awswrangler/test_redshift.py

Lines changed: 30 additions & 0 deletions

@@ -280,6 +280,36 @@ def test_to_redshift_spark_big(session, bucket, redshift_parameters):
     assert len(list(dataframe.columns)) == len(list(rows[0]))
 
 
+def test_stress_to_redshift_spark_big(session, bucket, redshift_parameters):
+    dataframe = session.spark_session.createDataFrame(
+        pd.DataFrame({
+            "A": list(range(1_000_000)),
+            "B": list(range(1_000_000)),
+            "C": list(range(1_000_000))
+        }))
+
+    for i in range(10):
+        print(i)
+        con = Redshift.generate_connection(
+            database="test",
+            host=redshift_parameters.get("RedshiftAddress"),
+            port=redshift_parameters.get("RedshiftPort"),
+            user="test",
+            password=redshift_parameters.get("RedshiftPassword"),
+        )
+        session.spark.to_redshift(
+            dataframe=dataframe,
+            path=f"s3://{bucket}/redshift-load/",
+            connection=con,
+            schema="public",
+            table="test",
+            iam_role=redshift_parameters.get("RedshiftRole"),
+            mode="overwrite",
+            min_num_partitions=4,
+        )
+        con.close()
+
+
 @pytest.mark.parametrize(
     "sample_name,mode,factor,diststyle,distkey,exc,sortstyle,sortkey",
     [
