Commit 561f139

Improving Spark.to_redshift() partitioning strategy
1 parent 9cf0eff commit 561f139

File tree

1 file changed: +8 -7 lines changed


awswrangler/spark.py

Lines changed: 8 additions & 7 deletions
@@ -2,8 +2,7 @@
 
 import pandas
 
-from pyspark.sql.functions import pandas_udf, PandasUDFType
-from pyspark.sql.functions import floor, rand
+from pyspark.sql.functions import pandas_udf, PandasUDFType, spark_partition_id
 from pyspark.sql.types import TimestampType
 
 from awswrangler.exceptions import MissingBatchDetected, UnsupportedFileFormat
@@ -99,7 +98,7 @@ def to_redshift(
         @pandas_udf(returnType="objects_paths string",
                     functionType=PandasUDFType.GROUPED_MAP)
         def write(pandas_dataframe):
-            del pandas_dataframe["partition_index"]
+            del pandas_dataframe["aws_data_wrangler_internal_partition_id"]
             paths = session_primitives.session.pandas.to_parquet(
                 dataframe=pandas_dataframe,
                 path=path,
@@ -109,10 +108,13 @@ def write(pandas_dataframe):
                 cast_columns=casts)
             return pandas.DataFrame.from_dict({"objects_paths": paths})
 
-        df_objects_paths = (dataframe.withColumn(
-            "partition_index", floor(rand() * num_partitions)).repartition(
-                "partition_index").groupby("partition_index").apply(write))
+        df_objects_paths = dataframe.repartition(numPartitions=num_partitions) \
+            .withColumn("aws_data_wrangler_internal_partition_id", spark_partition_id()) \
+            .groupby("aws_data_wrangler_internal_partition_id") \
+            .apply(write)
+
         objects_paths = list(df_objects_paths.toPandas()["objects_paths"])
+        dataframe.unpersist()
         num_files_returned = len(objects_paths)
         if num_files_returned != num_partitions:
             raise MissingBatchDetected(
@@ -140,7 +142,6 @@ def write(pandas_dataframe):
             sortkey=sortkey,
             mode=mode,
         )
-        dataframe.unpersist()
         self._session.s3.delete_objects(path=path)
 
     def create_glue_table(self,
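
For context, a minimal runnable sketch of the new partitioning pattern (the SparkSession, the toy DataFrame, the describe helper, and its demo schema are hypothetical stand-ins, not part of this commit). repartition(numPartitions=...) fixes the number of physical partitions, spark_partition_id() tags each row with the id of the partition it sits in, and the GROUPED_MAP pandas UDF then receives exactly one pandas DataFrame per partition. The replaced floor(rand() * num_partitions) bucketing assigned rows to buckets at random, so a bucket could come up empty by chance and trip the MissingBatchDetected check above, whereas a partition id can only be missing when that partition itself holds no rows.

import pandas

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, spark_partition_id

spark = SparkSession.builder.appName("partition-id-sketch").getOrCreate()

# Toy stand-in for the caller's DataFrame (hypothetical; any DataFrame works).
dataframe = spark.range(0, 1000).withColumnRenamed("id", "value")
num_partitions = 4

@pandas_udf(returnType="partition_id long, rows long",
            functionType=PandasUDFType.GROUPED_MAP)
def describe(pandas_dataframe):
    # Each call receives the rows of exactly one Spark partition as a single
    # pandas DataFrame; awswrangler writes Parquet at this point, this sketch
    # just reports the group's size.
    pid = int(pandas_dataframe["aws_data_wrangler_internal_partition_id"].iloc[0])
    return pandas.DataFrame.from_dict({"partition_id": [pid],
                                       "rows": [len(pandas_dataframe)]})

df_groups = dataframe.repartition(numPartitions=num_partitions) \
    .withColumn("aws_data_wrangler_internal_partition_id", spark_partition_id()) \
    .groupby("aws_data_wrangler_internal_partition_id") \
    .apply(describe)

# One row per physical partition is expected, i.e. exactly num_partitions rows.
print(df_groups.toPandas())

The commit also moves dataframe.unpersist() forward, releasing the (presumably cached) input as soon as the Parquet paths are collected rather than after the Redshift load completes.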
