@@ -1,12 +1,12 @@
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, Any
 import logging
 import os
 
 import pandas as pd  # type: ignore
 
 from pyspark.sql.functions import pandas_udf, PandasUDFType, spark_partition_id
 from pyspark.sql.types import TimestampType
-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, SparkSession
 
 from awswrangler.exceptions import MissingBatchDetected, UnsupportedFileFormat
 
@@ -18,14 +18,22 @@
 class Spark:
     def __init__(self, session):
         self._session = session
+        cpus: int = os.cpu_count()
+        if cpus == 1:
+            self._procs_io_bound: int = 1
+        else:
+            self._procs_io_bound = int(cpus / 2)
+        logging.info(f"_procs_io_bound: {self._procs_io_bound}")
 
-    def read_csv(self, **args):
-        spark = self._session.spark_session
+    def read_csv(self, **args) -> DataFrame:
+        spark: SparkSession = self._session.spark_session
         return spark.read.csv(**args)
 
     @staticmethod
-    def _extract_casts(dtypes):
-        casts = {}
+    def _extract_casts(dtypes: List[Tuple[str, str]]) -> Dict[str, str]:
+        casts: Dict[str, str] = {}
+        name: str
+        dtype: str
         for name, dtype in dtypes:
             if dtype in ["smallint", "int", "bigint"]:
                 casts[name] = "bigint"
@@ -35,7 +43,9 @@ def _extract_casts(dtypes):
         return casts
 
     @staticmethod
-    def date2timestamp(dataframe):
+    def date2timestamp(dataframe: DataFrame) -> DataFrame:
+        name: str
+        dtype: str
         for name, dtype in dataframe.dtypes:
             if dtype == "date":
                 dataframe = dataframe.withColumn(name, dataframe[name].cast(TimestampType()))
@@ -44,19 +54,19 @@ def date2timestamp(dataframe):
 
     def to_redshift(
             self,
-            dataframe,
-            path,
-            connection,
-            schema,
-            table,
-            iam_role,
-            diststyle="AUTO",
+            dataframe: DataFrame,
+            path: str,
+            connection: Any,
+            schema: str,
+            table: str,
+            iam_role: str,
+            diststyle: str = "AUTO",
             distkey=None,
-            sortstyle="COMPOUND",
+            sortstyle: str = "COMPOUND",
             sortkey=None,
-            min_num_partitions=200,
-            mode="append",
-    ):
+            min_num_partitions: int = 200,
+            mode: str = "append",
+    ) -> None:
         """
         Load Spark Dataframe as a Table on Amazon Redshift
 
@@ -78,54 +88,58 @@ def to_redshift(
         if path[-1] != "/":
             path += "/"
         self._session.s3.delete_objects(path=path)
-        spark = self._session.spark_session
-        casts = Spark._extract_casts(dataframe.dtypes)
+        spark: SparkSession = self._session.spark_session
+        casts: Dict[str, str] = Spark._extract_casts(dataframe.dtypes)
         dataframe = Spark.date2timestamp(dataframe)
         dataframe.cache()
-        num_rows = dataframe.count()
+        num_rows: int = dataframe.count()
         logger.info(f"Number of rows: {num_rows}")
+        num_partitions: int
         if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
             num_partitions = 1
         else:
-            num_slices = self._session.redshift.get_number_of_slices(redshift_conn=connection)
+            num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
             logger.debug(f"Number of slices on Redshift: {num_slices}")
             num_partitions = num_slices
             while num_partitions < min_num_partitions:
                 num_partitions += num_slices
         logger.debug(f"Number of partitions calculated: {num_partitions}")
         spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         session_primitives = self._session.primitives
+        par_col_name: str = "aws_data_wrangler_internal_partition_id"
 
         @pandas_udf(returnType="objects_paths string", functionType=PandasUDFType.GROUPED_MAP)
-        def write(pandas_dataframe):
+        def write(pandas_dataframe: pd.DataFrame) -> pd.DataFrame:
             # Exporting ARROW_PRE_0_15_IPC_FORMAT environment variable for
             # a temporary workaround while waiting for Apache Arrow updates
             # https://stackoverflow.com/questions/58273063/pandasudf-and-pyarrow-0-15-0
             os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
 
-            del pandas_dataframe["aws_data_wrangler_internal_partition_id"]
-            paths = session_primitives.session.pandas.to_parquet(dataframe=pandas_dataframe,
-                                                                 path=path,
-                                                                 preserve_index=False,
-                                                                 mode="append",
-                                                                 procs_cpu_bound=1,
-                                                                 cast_columns=casts)
+            del pandas_dataframe[par_col_name]
+            paths: List[str] = session_primitives.session.pandas.to_parquet(dataframe=pandas_dataframe,
+                                                                            path=path,
+                                                                            preserve_index=False,
+                                                                            mode="append",
+                                                                            procs_cpu_bound=1,
+                                                                            procs_io_bound=1,
+                                                                            cast_columns=casts)
             return pd.DataFrame.from_dict({"objects_paths": paths})
 
-        df_objects_paths = dataframe.repartition(numPartitions=num_partitions) \
-            .withColumn("aws_data_wrangler_internal_partition_id", spark_partition_id()) \
-            .groupby("aws_data_wrangler_internal_partition_id") \
-            .apply(write)
+        df_objects_paths = dataframe.repartition(numPartitions=num_partitions)  # type: ignore
+        df_objects_paths = df_objects_paths.withColumn(par_col_name, spark_partition_id())  # type: ignore
+        df_objects_paths = df_objects_paths.groupby(par_col_name).apply(write)  # type: ignore
 
-        objects_paths = list(df_objects_paths.toPandas()["objects_paths"])
+        objects_paths: List[str] = list(df_objects_paths.toPandas()["objects_paths"])
         dataframe.unpersist()
-        num_files_returned = len(objects_paths)
+        num_files_returned: int = len(objects_paths)
         if num_files_returned != num_partitions:
             raise MissingBatchDetected(f"{num_files_returned} files returned. {num_partitions} expected.")
         logger.debug(f"List of objects returned: {objects_paths}")
         logger.debug(f"Number of objects returned from UDF: {num_files_returned}")
-        manifest_path = f"{path}manifest.json"
-        self._session.redshift.write_load_manifest(manifest_path=manifest_path, objects_paths=objects_paths)
+        manifest_path: str = f"{path}manifest.json"
+        self._session.redshift.write_load_manifest(manifest_path=manifest_path,
+                                                   objects_paths=objects_paths,
+                                                   procs_io_bound=self._procs_io_bound)
         self._session.redshift.load_table(dataframe=dataframe,
                                           dataframe_type="spark",
                                           manifest_path=manifest_path,
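
Note on the change above: the partition-count logic rounds the Redshift slice count up in whole multiples until it reaches min_num_partitions, so every slice receives the same number of Parquet objects during the COPY. A minimal standalone sketch of that rule follows; the helper name and the small-frame threshold default are illustrative only (the value of MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE is not shown in this diff).

    # Illustrative sketch, not part of awswrangler: the partition-count rule
    # implemented inline in Spark.to_redshift above.
    def calc_num_partitions(num_rows: int,
                            num_slices: int,
                            min_num_partitions: int = 200,
                            distribute_threshold: int = 1000) -> int:  # threshold value assumed
        if num_rows < distribute_threshold:
            return 1  # tiny frames are written as a single object
        # Grow in whole multiples of the slice count until the minimum is met,
        # keeping the object count evenly divisible across Redshift slices.
        num_partitions = num_slices
        while num_partitions < min_num_partitions:
            num_partitions += num_slices
        return num_partitions

    # e.g. an 8-slice cluster with the default minimum of 200 -> 200 partitions;
    # a 12-slice cluster -> 204 (the first multiple of 12 that is >= 200).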
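The write path itself relies on the Spark 2.4-era grouped-map pandas UDF: the DataFrame is tagged with spark_partition_id(), grouped by that tag, and each group is handed to the write UDF as a pandas DataFrame. A self-contained sketch of that pattern, assuming a live SparkSession bound to spark and pyarrow installed (the column and function names below are illustrative, not from the library):

    import pandas as pd  # type: ignore
    from pyspark.sql.functions import pandas_udf, PandasUDFType, spark_partition_id

    # One pandas DataFrame per Spark partition: the same mechanism to_redshift
    # uses to fan out the Parquet writes.
    @pandas_udf(returnType="partition_id long, n long", functionType=PandasUDFType.GROUPED_MAP)
    def summarize(pdf: pd.DataFrame) -> pd.DataFrame:
        pid = int(pdf["pid"].iloc[0])
        return pd.DataFrame({"partition_id": [pid], "n": [len(pdf)]})

    df = spark.range(1000).repartition(8)
    result = df.withColumn("pid", spark_partition_id()).groupby("pid").apply(summarize)
    result.show()  # eight rows, one per partition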