@@ -71,7 +71,7 @@ def to_redshift(
 
         :param dataframe: Pandas Dataframe
         :param path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/)
-        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param connection: Glue connection name (str) OR a PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
         :param schema: The Redshift Schema for the table
         :param table: The name of the desired Redshift table
         :param iam_role: AWS IAM role with the related permissions
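Caller-side sketch (not part of the diff) showing the two accepted forms of the connection argument; the session construction, bucket, role ARN, Glue connection name, and the generate_connection() parameter names are placeholder assumptions, not values taken from this PR:

import awswrangler

session = awswrangler.Session(spark_session=spark)  # assumes an existing SparkSession named `spark`

# Option A: pass the Glue connection name as a string; the method resolves,
# uses, and closes the connection internally (see the rollback/close handling in the next hunk).
session.spark.to_redshift(dataframe=df,
                          path="s3://BUCKET_NAME/ANY_NAME/",
                          connection="my-glue-connection",  # hypothetical Glue connection name
                          schema="public",
                          table="my_table",
                          iam_role="arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME")

# Option B: pass an already-open PEP 249 connection; the caller keeps ownership and closes it.
con = session.redshift.generate_connection(database="DB_NAME", host="HOST", port=5439,
                                           user="USER", password="PASS")  # parameter names assumed
session.spark.to_redshift(dataframe=df,
                          path="s3://BUCKET_NAME/ANY_NAME/",
                          connection=con,
                          schema="public",
                          table="my_table",
                          iam_role="arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME")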
@@ -93,68 +93,83 @@ def to_redshift(
         dataframe.cache()
         num_rows: int = dataframe.count()
         logger.info(f"Number of rows: {num_rows}")
-        num_partitions: int
-        if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
-            num_partitions = 1
-        else:
-            num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
-            logger.debug(f"Number of slices on Redshift: {num_slices}")
-            num_partitions = num_slices
-            while num_partitions < min_num_partitions:
-                num_partitions += num_slices
-        logger.debug(f"Number of partitions calculated: {num_partitions}")
-        spark.conf.set("spark.sql.execution.arrow.enabled", "true")
-        session_primitives = self._session.primitives
-        par_col_name: str = "aws_data_wrangler_internal_partition_id"
 
-        @pandas_udf(returnType="objects_paths string", functionType=PandasUDFType.GROUPED_MAP)
-        def write(pandas_dataframe: pd.DataFrame) -> pd.DataFrame:
-            # Exporting ARROW_PRE_0_15_IPC_FORMAT environment variable for
-            # a temporary workaround while waiting for Apache Arrow updates
-            # https://stackoverflow.com/questions/58273063/pandasudf-and-pyarrow-0-15-0
-            os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
+        generated_conn: bool = False
+        if type(connection) == str:
+            logger.debug("Glue connection (str) provided.")
+            connection = self._session.glue.get_connection(name=connection)
+            generated_conn = True
 
-            del pandas_dataframe[par_col_name]
-            paths: List[str] = session_primitives.session.pandas.to_parquet(dataframe=pandas_dataframe,
-                                                                            path=path,
-                                                                            preserve_index=False,
-                                                                            mode="append",
-                                                                            procs_cpu_bound=1,
-                                                                            procs_io_bound=1,
-                                                                            cast_columns=casts)
-            return pd.DataFrame.from_dict({"objects_paths": paths})
+        try:
+            num_partitions: int
+            if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
+                num_partitions = 1
+            else:
+                num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
+                logger.debug(f"Number of slices on Redshift: {num_slices}")
+                num_partitions = num_slices
+                while num_partitions < min_num_partitions:
+                    num_partitions += num_slices
+            logger.debug(f"Number of partitions calculated: {num_partitions}")
+            spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+            session_primitives = self._session.primitives
+            par_col_name: str = "aws_data_wrangler_internal_partition_id"
 
-        df_objects_paths: DataFrame = dataframe.repartition(numPartitions=num_partitions)  # type: ignore
-        df_objects_paths: DataFrame = df_objects_paths.withColumn(par_col_name, spark_partition_id())  # type: ignore
-        df_objects_paths: DataFrame = df_objects_paths.groupby(par_col_name).apply(write)  # type: ignore
+            @pandas_udf(returnType="objects_paths string", functionType=PandasUDFType.GROUPED_MAP)
+            def write(pandas_dataframe: pd.DataFrame) -> pd.DataFrame:
+                # Exporting ARROW_PRE_0_15_IPC_FORMAT environment variable for
+                # a temporary workaround while waiting for Apache Arrow updates
+                # https://stackoverflow.com/questions/58273063/pandasudf-and-pyarrow-0-15-0
+                os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
 
-        objects_paths: List[str] = list(df_objects_paths.toPandas()["objects_paths"])
-        dataframe.unpersist()
-        num_files_returned: int = len(objects_paths)
-        if num_files_returned != num_partitions:
-            raise MissingBatchDetected(f"{num_files_returned} files returned. {num_partitions} expected.")
-        logger.debug(f"List of objects returned: {objects_paths}")
-        logger.debug(f"Number of objects returned from UDF: {num_files_returned}")
-        manifest_path: str = f"{path}manifest.json"
-        self._session.redshift.write_load_manifest(manifest_path=manifest_path,
-                                                   objects_paths=objects_paths,
-                                                   procs_io_bound=self._procs_io_bound)
-        self._session.redshift.load_table(dataframe=dataframe,
-                                          dataframe_type="spark",
-                                          manifest_path=manifest_path,
-                                          schema_name=schema,
-                                          table_name=table,
-                                          redshift_conn=connection,
-                                          preserve_index=False,
-                                          num_files=num_partitions,
-                                          iam_role=iam_role,
-                                          diststyle=diststyle,
-                                          distkey=distkey,
-                                          sortstyle=sortstyle,
-                                          sortkey=sortkey,
-                                          mode=mode,
-                                          cast_columns=casts)
-        self._session.s3.delete_objects(path=path, procs_io_bound=self._procs_io_bound)
+                del pandas_dataframe[par_col_name]
+                paths: List[str] = session_primitives.session.pandas.to_parquet(dataframe=pandas_dataframe,
+                                                                                path=path,
+                                                                                preserve_index=False,
+                                                                                mode="append",
+                                                                                procs_cpu_bound=1,
+                                                                                procs_io_bound=1,
+                                                                                cast_columns=casts)
+                return pd.DataFrame.from_dict({"objects_paths": paths})
+
+            df_objects_paths: DataFrame = dataframe.repartition(numPartitions=num_partitions)  # type: ignore
+            df_objects_paths = df_objects_paths.withColumn(par_col_name, spark_partition_id())  # type: ignore
+            df_objects_paths = df_objects_paths.groupby(par_col_name).apply(write)  # type: ignore
+
+            objects_paths: List[str] = list(df_objects_paths.toPandas()["objects_paths"])
+            dataframe.unpersist()
+            num_files_returned: int = len(objects_paths)
+            if num_files_returned != num_partitions:
+                raise MissingBatchDetected(f"{num_files_returned} files returned. {num_partitions} expected.")
+            logger.debug(f"List of objects returned: {objects_paths}")
+            logger.debug(f"Number of objects returned from UDF: {num_files_returned}")
+            manifest_path: str = f"{path}manifest.json"
+            self._session.redshift.write_load_manifest(manifest_path=manifest_path,
+                                                       objects_paths=objects_paths,
+                                                       procs_io_bound=self._procs_io_bound)
+            self._session.redshift.load_table(dataframe=dataframe,
+                                              dataframe_type="spark",
+                                              manifest_path=manifest_path,
+                                              schema_name=schema,
+                                              table_name=table,
+                                              redshift_conn=connection,
+                                              preserve_index=False,
+                                              num_files=num_partitions,
+                                              iam_role=iam_role,
+                                              diststyle=diststyle,
+                                              distkey=distkey,
+                                              sortstyle=sortstyle,
+                                              sortkey=sortkey,
+                                              mode=mode,
+                                              cast_columns=casts)
+            self._session.s3.delete_objects(path=path, procs_io_bound=self._procs_io_bound)
+        except Exception as ex:
+            connection.rollback()
+            if generated_conn is True:
+                connection.close()
+            raise ex
+        if generated_conn is True:
+            connection.close()
 
     def create_glue_table(self,
                           database,
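For reference, a standalone sketch of the partition-sizing heuristic that this PR moves inside the new try block: use a single partition for small frames, otherwise round the requested minimum up to a multiple of the Redshift slice count so files spread evenly across slices. The helper name, example values, and the row threshold (the actual value of MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE) are assumptions for illustration:

def pick_num_partitions(num_rows: int, num_slices: int, min_num_partitions: int,
                        min_rows_to_distribute: int = 100_000) -> int:
    """Return 1 for small frames, otherwise the smallest multiple of num_slices
    that is >= min_num_partitions."""
    if num_rows < min_rows_to_distribute:
        return 1
    num_partitions = num_slices
    while num_partitions < min_num_partitions:
        num_partitions += num_slices
    return num_partitions

# e.g. 2 slices and a requested minimum of 3 partitions -> 4 partitions (4 output files)
assert pick_num_partitions(num_rows=10_000_000, num_slices=2, min_num_partitions=3) == 4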