Skip to content

Commit 3b5e2cc

Browse files
aeeladawy and jaidisido
authored
Feature/add use column names to redshift copy (#1437)
* Add use_column_names to copy/_copy/copy_from_files
* Add copy_upsert_with_column_names test
* Remove sql line
* Fix duplicate definition
* Linting
* format black
* disable toomanylocals for copy(
* fix pydocstyle
* Blank line contains whitespace
* Minor refactoring
* Minor sql query fix

Co-authored-by: Abdel Jaidi <[email protected]>
1 parent 7391d73 commit 3b5e2cc

File tree

2 files changed

+98
-2
lines changed

2 files changed

+98
-2
lines changed

awswrangler/redshift.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def _copy(
131131
schema: Optional[str] = None,
132132
manifest: Optional[bool] = False,
133133
sql_copy_extra_params: Optional[List[str]] = None,
134+
column_names: Optional[List[str]] = None,
134135
) -> None:
135136
if schema is None:
136137
table_name: str = f'"{table}"'
@@ -145,7 +146,9 @@ def _copy(
145146
boto3_session=boto3_session,
146147
)
147148
ser_json_str: str = " SERIALIZETOJSON" if serialize_to_json else ""
148-
sql: str = f"COPY {table_name}\nFROM '{path}' {auth_str}\nFORMAT AS PARQUET{ser_json_str}"
149+
column_names_str: str = f"({','.join(column_names)})" if column_names else ""
150+
sql = f"COPY {table_name} {column_names_str}\nFROM '{path}' {auth_str}\nFORMAT AS PARQUET{ser_json_str}"
151+
149152
if manifest:
150153
sql += "\nMANIFEST"
151154
if sql_copy_extra_params:
@@ -1250,6 +1253,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
12501253
boto3_session: Optional[boto3.Session] = None,
12511254
s3_additional_kwargs: Optional[Dict[str, str]] = None,
12521255
precombine_key: Optional[str] = None,
1256+
column_names: Optional[List[str]] = None,
12531257
) -> None:
12541258
"""Load Parquet files from S3 to a Table on Amazon Redshift (Through COPY command).
12551259
@@ -1352,6 +1356,8 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
13521356
When there is a primary_key match during upsert, this column will change the upsert method,
13531357
comparing the values of the specified column from source and target, and keeping the
13541358
larger of the two. Will only work when mode = upsert.
1359+
column_names: List[str], optional
1360+
List of column names to map source data fields to the target columns.
13551361
13561362
Returns
13571363
-------
@@ -1416,6 +1422,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
14161422
serialize_to_json=serialize_to_json,
14171423
sql_copy_extra_params=sql_copy_extra_params,
14181424
manifest=manifest,
1425+
column_names=column_names,
14191426
)
14201427
if table != created_table: # upsert
14211428
_upsert(
@@ -1425,6 +1432,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
14251432
temp_table=created_table,
14261433
primary_keys=primary_keys,
14271434
precombine_key=precombine_key,
1435+
column_names=column_names,
14281436
)
14291437
if commit_transaction:
14301438
con.commit()
@@ -1436,7 +1444,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
14361444
con.autocommit = autocommit_temp
14371445

14381446

1439-
def copy( # pylint: disable=too-many-arguments
1447+
def copy( # pylint: disable=too-many-arguments,too-many-locals
14401448
df: pd.DataFrame,
14411449
path: str,
14421450
con: redshift_connector.Connection,
@@ -1466,6 +1474,7 @@ def copy( # pylint: disable=too-many-arguments
14661474
s3_additional_kwargs: Optional[Dict[str, str]] = None,
14671475
max_rows_by_file: Optional[int] = 10_000_000,
14681476
precombine_key: Optional[str] = None,
1477+
use_column_names: bool = False,
14691478
) -> None:
14701479
"""Load Pandas DataFrame as a Table on Amazon Redshift using parquet files on S3 as stage.
14711480
@@ -1568,6 +1577,10 @@ def copy( # pylint: disable=too-many-arguments
15681577
When there is a primary_key match during upsert, this column will change the upsert method,
15691578
comparing the values of the specified column from source and target, and keeping the
15701579
larger of the two. Will only work when mode = upsert.
1580+
use_column_names: bool
1581+
If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
1582+
E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
1583+
inserted into the database columns `col1` and `col3`.
15711584
15721585
Returns
15731586
-------
@@ -1592,6 +1605,7 @@ def copy( # pylint: disable=too-many-arguments
15921605
"""
15931606
path = path[:-1] if path.endswith("*") else path
15941607
path = path if path.endswith("/") else f"{path}/"
1608+
column_names = [f'"{column}"' for column in df.columns] if use_column_names else []
15951609
session: boto3.Session = _utils.ensure_session(session=boto3_session)
15961610
if s3.list_objects(path=path, boto3_session=session, s3_additional_kwargs=s3_additional_kwargs):
15971611
raise exceptions.InvalidArgument(
@@ -1636,6 +1650,7 @@ def copy( # pylint: disable=too-many-arguments
16361650
s3_additional_kwargs=s3_additional_kwargs,
16371651
sql_copy_extra_params=sql_copy_extra_params,
16381652
precombine_key=precombine_key,
1653+
column_names=column_names,
16391654
)
16401655
finally:
16411656
if keep_files is False:

tests/test_redshift.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,3 +1104,84 @@ def test_to_sql_multi_transaction(redshift_table, redshift_con):
11041104
df3 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table} ORDER BY id", con=redshift_con)
11051105
assert len(df.index) + len(df2.index) == len(df3.index)
11061106
assert len(df.columns) == len(df3.columns)
1107+
1108+
1109+
def test_copy_upsert_with_column_names(path, redshift_table, redshift_con, databases_parameters):
1110+
df = pd.DataFrame({"id": list((range(1_000))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(1_000)])})
1111+
df3 = pd.DataFrame(
1112+
{"id": list((range(1_000, 1_500))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(500)])}
1113+
)
1114+
1115+
# CREATE
1116+
path = f"{path}upsert/test_redshift_copy_upsert_with_column_names/"
1117+
wr.redshift.copy(
1118+
df=df,
1119+
path=path,
1120+
con=redshift_con,
1121+
schema="public",
1122+
table=redshift_table,
1123+
mode="overwrite",
1124+
index=False,
1125+
primary_keys=["id"],
1126+
iam_role=databases_parameters["redshift"]["role"],
1127+
use_column_names=True,
1128+
)
1129+
path = f"{path}upsert/test_redshift_copy_upsert_with_column_names2/"
1130+
df2 = wr.redshift.unload(
1131+
sql=f"SELECT * FROM public.{redshift_table}",
1132+
con=redshift_con,
1133+
iam_role=databases_parameters["redshift"]["role"],
1134+
path=path,
1135+
keep_files=False,
1136+
)
1137+
assert len(df.index) == len(df2.index)
1138+
assert len(df.columns) == len(df2.columns)
1139+
1140+
# UPSERT
1141+
path = f"{path}upsert/test_redshift_copy_upsert_with_column_names3/"
1142+
wr.redshift.copy(
1143+
df=df3,
1144+
path=path,
1145+
con=redshift_con,
1146+
schema="public",
1147+
table=redshift_table,
1148+
mode="upsert",
1149+
index=False,
1150+
primary_keys=["id"],
1151+
iam_role=databases_parameters["redshift"]["role"],
1152+
use_column_names=True,
1153+
)
1154+
path = f"{path}upsert/test_redshift_copy_upsert_with_column_names4/"
1155+
df4 = wr.redshift.unload(
1156+
sql=f"SELECT * FROM public.{redshift_table}",
1157+
con=redshift_con,
1158+
iam_role=databases_parameters["redshift"]["role"],
1159+
path=path,
1160+
keep_files=False,
1161+
)
1162+
assert len(df.index) + len(df3.index) == len(df4.index)
1163+
assert len(df.columns) == len(df4.columns)
1164+
1165+
# UPSERT 2 + lock
1166+
wr.redshift.copy(
1167+
df=df3,
1168+
path=path,
1169+
con=redshift_con,
1170+
schema="public",
1171+
table=redshift_table,
1172+
mode="upsert",
1173+
index=False,
1174+
iam_role=databases_parameters["redshift"]["role"],
1175+
lock=True,
1176+
use_column_names=True,
1177+
)
1178+
path = f"{path}upsert/test_redshift_copy_upsert_with_column_names4/"
1179+
df4 = wr.redshift.unload(
1180+
sql=f"SELECT * FROM public.{redshift_table}",
1181+
con=redshift_con,
1182+
iam_role=databases_parameters["redshift"]["role"],
1183+
path=path,
1184+
keep_files=False,
1185+
)
1186+
assert len(df.index) + len(df3.index) == len(df4.index)
1187+
assert len(df.columns) == len(df4.columns)

0 commit comments

Comments (0)