@@ -3,16 +3,12 @@
 import logging
 
 import pg8000  # type: ignore
+import pyarrow as pa  # type: ignore
 
 from awswrangler import data_types
-from awswrangler.exceptions import (
-    RedshiftLoadError,
-    InvalidDataframeType,
-    InvalidRedshiftDiststyle,
-    InvalidRedshiftDistkey,
-    InvalidRedshiftSortstyle,
-    InvalidRedshiftSortkey,
-)
+from awswrangler.exceptions import (RedshiftLoadError, InvalidDataframeType, InvalidRedshiftDiststyle,
+                                    InvalidRedshiftDistkey, InvalidRedshiftSortstyle, InvalidRedshiftSortkey,
+                                    InvalidRedshiftPrimaryKeys)
 
 logger = logging.getLogger(__name__)
 
@@ -165,6 +161,7 @@ def load_table(dataframe,
                    distkey=None,
                    sortstyle="COMPOUND",
                    sortkey=None,
+                   primary_keys: Optional[List[str]] = None,
                    mode="append",
                    preserve_index=False,
                    cast_columns=None):
@@ -184,11 +181,14 @@ def load_table(dataframe,
         :param distkey: Specifies a column name or positional number for the distribution key
         :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED" (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
         :param sortkey: List of columns to be sorted
-        :param mode: append or overwrite
+        :param primary_keys: List of primary key column names (declared on create, used to match rows on upsert)
+        :param mode: append, overwrite or upsert
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
         :param cast_columns: Dictionary of column names and Redshift types to be casted. (E.g. {"col name": "INT", "col2 name": "FLOAT"})
         :return: None
         """
+        final_table_name: Optional[str] = None
+        temp_table_name: Optional[str] = None
         cursor = redshift_conn.cursor()
         if mode == "overwrite":
             Redshift._create_table(cursor=cursor,
@@ -200,13 +200,27 @@ def load_table(dataframe,
                                    distkey=distkey,
                                    sortstyle=sortstyle,
                                    sortkey=sortkey,
+                                   primary_keys=primary_keys,
                                    preserve_index=preserve_index,
                                    cast_columns=cast_columns)
+            table_name = f"{schema_name}.{table_name}"
+        elif mode == "upsert":
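+            # Stage the COPY into a session-scoped temporary table cloned from the
+            # target, tagged with a random suffix to avoid name collisions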
+            guid: str = pa.compat.guid()
+            temp_table_name = f"temp_redshift_{guid}"
+            final_table_name = table_name
+            table_name = temp_table_name
+            sql: str = f"CREATE TEMPORARY TABLE {temp_table_name} (LIKE {schema_name}.{final_table_name})"
+            logger.debug(sql)
+            cursor.execute(sql)
+        else:
+            table_name = f"{schema_name}.{table_name}"
+
         sql = ("-- AWS DATA WRANGLER\n"
-               f"COPY {schema_name}.{table_name} FROM '{manifest_path}'\n"
+               f"COPY {table_name} FROM '{manifest_path}'\n"
                f"IAM_ROLE '{iam_role}'\n"
                "MANIFEST\n"
                "FORMAT AS PARQUET")
+        logger.debug(sql)
         cursor.execute(sql)
         cursor.execute("-- AWS DATA WRANGLER\nSELECT pg_last_copy_id() AS query_id")
         query_id = cursor.fetchall()[0][0]
@@ -219,6 +233,23 @@ def load_table(dataframe,
             cursor.close()
             raise RedshiftLoadError(
                 f"Redshift load rolled back. {num_files_loaded} files counted. {num_files} expected.")
+
+        if (mode == "upsert") and (final_table_name is not None):
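+            # Fall back to the primary keys declared on the target table when the
+            # caller did not pass any explicitly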
+            if not primary_keys:
+                primary_keys = Redshift.get_primary_keys(connection=redshift_conn,
+                                                         schema=schema_name,
+                                                         table=final_table_name)
+            if not primary_keys:
+                raise InvalidRedshiftPrimaryKeys()
+            equals_clause = f"{final_table_name}.%s = {temp_table_name}.%s"
+            join_clause = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])
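+            # Delete-then-insert merge: drop target rows that collide with the staged
+            # rows on every primary key, then append the full staging table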
+            sql = f"DELETE FROM {schema_name}.{final_table_name} USING {temp_table_name} WHERE {join_clause}"
+            logger.debug(sql)
+            cursor.execute(sql)
+            sql = f"INSERT INTO {schema_name}.{final_table_name} SELECT * FROM {temp_table_name}"
+            logger.debug(sql)
+            cursor.execute(sql)
+
         redshift_conn.commit()
         cursor.close()
 
@@ -232,6 +263,7 @@ def _create_table(cursor,
                       distkey=None,
                       sortstyle="COMPOUND",
                       sortkey=None,
+                      primary_keys: Optional[List[str]] = None,
                       preserve_index=False,
                       cast_columns=None):
         """
@@ -246,6 +278,7 @@ def _create_table(cursor,
246278 :param distkey: Specifies a column name or positional number for the distribution key
247279 :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED" (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
248280 :param sortkey: List of columns to be sorted
281+ :param primary_keys: Primary keys
249282 :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
250283 :param cast_columns: Dictionary of columns names and Redshift types to be casted. (E.g. {"col name": "INT", "col2 name": "FLOAT"})
251284 :return: None
@@ -273,22 +306,43 @@ def _create_table(cursor,
                                       distkey=distkey,
                                       sortstyle=sortstyle,
                                       sortkey=sortkey)
-        cols_str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
-        distkey_str = ""
+        cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
+        primary_keys_str: str = ""
+        if primary_keys:
+            primary_keys_str = f",\nPRIMARY KEY ({', '.join(primary_keys)})"
+        distkey_str: str = ""
         if distkey and diststyle == "KEY":
             distkey_str = f"\nDISTKEY({distkey})"
-        sortkey_str = ""
+        sortkey_str: str = ""
         if sortkey:
             sortkey_str = f"\n{sortstyle} SORTKEY({','.join(sortkey)})"
         sql = (f"-- AWS DATA WRANGLER\n"
                f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n"
                f"{cols_str}"
+               f"{primary_keys_str}"
                f")\nDISTSTYLE {diststyle}"
                f"{distkey_str}"
                f"{sortkey_str}")
         logger.debug(f"Create table query:\n{sql}")
         cursor.execute(sql)
 
+    @staticmethod
+    def get_primary_keys(connection, schema, table):
+        """
+        Get the primary key column names declared on a Redshift table
+        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param schema: Schema name
+        :param table: Redshift table name
+        :return: List of primary key column names (List[str])
+        """
+        cursor = connection.cursor()
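+        # The declared PRIMARY KEY shows up in pg_indexes as an index definition;
+        # parse the column list between its parentheses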
+        cursor.execute(f"SELECT indexdef FROM pg_indexes WHERE schemaname = '{schema}' AND tablename = '{table}'")
+        rows = cursor.fetchall()
+        cursor.close()
+        if not rows:
+            return []
+        rfields = rows[0][0].split('(')[1].strip(')').split(',')
+        return [field.strip().strip('"') for field in rfields]
+
     @staticmethod
     def _validate_parameters(schema, diststyle, distkey, sortstyle, sortkey):
         """
@@ -347,8 +401,8 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c
             raise InvalidDataframeType(dataframe_type)
         return schema_built
 
-    @staticmethod
-    def to_parquet(sql: str,
+    def to_parquet(self,
+                   sql: str,
                    path: str,
                    iam_role: str,
                    connection: Any,
@@ -366,8 +420,11 @@ def to_parquet(sql: str,
         path = path if path[-1] == "/" else path + "/"
         cursor: Any = connection.cursor()
         partition_str: str = ""
+        manifest_str: str = ""
         if partition_cols is not None:
             partition_str = f"PARTITION BY ({','.join([x for x in partition_cols])})\n"
+        else:
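+            # Without PARTITION BY, ask UNLOAD for a MANIFEST so the output files
+            # can be discovered and awaited after the query finishes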
+            manifest_str = "\nmanifest"
         query: str = f"-- AWS DATA WRANGLER\n" \
                      f"UNLOAD ('{sql}')\n" \
                      f"TO '{path}'\n" \
@@ -376,7 +433,8 @@ def to_parquet(sql: str,
                      f"PARALLEL ON\n" \
                      f"ENCRYPTED\n" \
                      f"{partition_str}" \
-                     f"FORMAT PARQUET;"
+                     f"FORMAT PARQUET" \
+                     f"{manifest_str};"
         logger.debug(f"query:\n{query}")
         cursor.execute(query)
         query = "-- AWS DATA WRANGLER\nSELECT pg_last_query_id() AS query_id"
@@ -391,4 +449,8 @@ def to_parquet(sql: str,
         logger.debug(f"paths: {paths}")
         connection.commit()
         cursor.close()
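+        # Block until the manifest and every unloaded object are actually visible
+        # in S3 before handing the paths back to the caller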
+        if manifest_str != "":
+            self._session.s3.wait_object_exists(path=f"{path}manifest")
+        for p in paths:
+            self._session.s3.wait_object_exists(path=p)
         return paths