 
 import pandas
 
-from pyspark.sql.functions import pandas_udf, PandasUDFType
-from pyspark.sql.functions import floor, rand
+from pyspark.sql.functions import pandas_udf, PandasUDFType, spark_partition_id
 from pyspark.sql.types import TimestampType
 
 from awswrangler.exceptions import MissingBatchDetected, UnsupportedFileFormat
@@ -99,7 +98,7 @@ def to_redshift(
         @pandas_udf(returnType="objects_paths string",
                     functionType=PandasUDFType.GROUPED_MAP)
         def write(pandas_dataframe):
-            del pandas_dataframe["partition_index"]
+            del pandas_dataframe["aws_data_wrangler_internal_partition_id"]
             paths = session_primitives.session.pandas.to_parquet(
                 dataframe=pandas_dataframe,
                 path=path,
@@ -109,10 +108,13 @@ def write(pandas_dataframe):
                 cast_columns=casts)
             return pandas.DataFrame.from_dict({"objects_paths": paths})
 
-        df_objects_paths = (dataframe.withColumn(
-            "partition_index", floor(rand() * num_partitions)).repartition(
-                "partition_index").groupby("partition_index").apply(write))
+        df_objects_paths = dataframe.repartition(numPartitions=num_partitions) \
+            .withColumn("aws_data_wrangler_internal_partition_id", spark_partition_id()) \
+            .groupby("aws_data_wrangler_internal_partition_id") \
+            .apply(write)
+
         objects_paths = list(df_objects_paths.toPandas()["objects_paths"])
+        dataframe.unpersist()
         num_files_returned = len(objects_paths)
         if num_files_returned != num_partitions:
             raise MissingBatchDetected(
@@ -140,7 +142,6 @@ def write(pandas_dataframe):
             sortkey=sortkey,
             mode=mode,
         )
-        dataframe.unpersist()
         self._session.s3.delete_objects(path=path)
 
     def create_glue_table(self,
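The core of this change is swapping the random bucketing column, `floor(rand() * num_partitions)`, for the physical partition id taken after an explicit `repartition()`. A minimal standalone sketch of that pattern, assuming a local SparkSession, a toy input DataFrame, and an illustrative column name (none of this is awswrangler code):

```python
import pandas

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, spark_partition_id

spark = SparkSession.builder.master("local[2]").getOrCreate()

num_partitions = 4
df = spark.range(100)  # toy stand-in for the cached input dataframe

@pandas_udf(returnType="num_rows long", functionType=PandasUDFType.GROUPED_MAP)
def write(pdf):
    # Drop the internal bookkeeping column before writing, as the diff does.
    del pdf["internal_partition_id"]
    # A real writer would persist pdf (e.g. to Parquet) and return its paths.
    return pandas.DataFrame.from_dict({"num_rows": [len(pdf)]})

result = (df.repartition(num_partitions)
            .withColumn("internal_partition_id", spark_partition_id())
            .groupby("internal_partition_id")
            .apply(write))

# repartition() spreads rows evenly (round-robin), so each of the
# num_partitions partitions is non-empty and the UDF runs exactly once per
# partition. With floor(rand() * num_partitions), an index could randomly
# receive no rows, which is what the MissingBatchDetected check guards against.
assert result.count() == num_partitions
```

The commit also moves `dataframe.unpersist()` forward, releasing the cached input as soon as the Parquet paths are collected instead of holding it through the Redshift load.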