diff --git a/awswrangler/mysql.py b/awswrangler/mysql.py
index 7d002c52f..42df19ef8 100644
--- a/awswrangler/mysql.py
+++ b/awswrangler/mysql.py
@@ -93,7 +93,7 @@ def connect(
     write_timeout: int | None = None,
     connect_timeout: int = 10,
     cursorclass: type["Cursor"] | None = None,
-) -> "pymysql.connections.Connection":  # type: ignore[type-arg]
+) -> "pymysql.connections.Connection":
     """Return a pymysql connection from a Glue Catalog Connection or Secrets Manager.

     https://pymysql.readthedocs.io
@@ -231,7 +231,7 @@ def read_sql_query(
 @_utils.check_optional_dependency(pymysql, "pymysql")
 def read_sql_query(
     sql: str,
-    con: "pymysql.connections.Connection",  # type: ignore[type-arg]
+    con: "pymysql.connections.Connection",
     index_col: str | list[str] | None = None,
     params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = None,
     chunksize: int | None = None,
@@ -351,7 +351,7 @@ def read_sql_table(
 @_utils.check_optional_dependency(pymysql, "pymysql")
 def read_sql_table(
     table: str,
-    con: "pymysql.connections.Connection",  # type: ignore[type-arg]
+    con: "pymysql.connections.Connection",
     schema: str | None = None,
     index_col: str | list[str] | None = None,
     params: list[Any] | tuple[Any, ...] | dict[Any, Any] | None = None,
@@ -439,7 +439,7 @@ def read_sql_table(
 @apply_configs
 def to_sql(
     df: pd.DataFrame,
-    con: "pymysql.connections.Connection",  # type: ignore[type-arg]
+    con: "pymysql.connections.Connection",
     table: str,
     schema: str,
     mode: _ToSqlModeLiteral = "append",
diff --git a/awswrangler/redshift/_read.py b/awswrangler/redshift/_read.py
index de71bda98..57d4190ec 100644
--- a/awswrangler/redshift/_read.py
+++ b/awswrangler/redshift/_read.py
@@ -241,6 +241,7 @@ def unload_to_files(
     kms_key_id: str | None = None,
     manifest: bool = False,
     partition_cols: list[str] | None = None,
+    cleanpath: bool = False,
     boto3_session: boto3.Session | None = None,
 ) -> None:
     """Unload Parquet files on s3 from a Redshift query result (Through the UNLOAD command).
@@ -294,6 +295,21 @@ def unload_to_files(
         Unload a manifest file on S3.
     partition_cols
         Specifies the partition keys for the unload operation.
+    cleanpath
+        Use CLEANPATH instead of ALLOWOVERWRITE. When True, uses CLEANPATH to remove existing files
+        located in the Amazon S3 path before unloading files. When False (default), uses ALLOWOVERWRITE
+        to overwrite existing files, including the manifest file. These options are mutually exclusive.
+
+        ALLOWOVERWRITE: By default, UNLOAD fails if it finds files that it would possibly overwrite.
+        If ALLOWOVERWRITE is specified, UNLOAD overwrites existing files, including the manifest file.
+
+        CLEANPATH: Removes existing files located in the Amazon S3 path specified in the TO clause
+        before unloading files to the specified location. If you include the PARTITION BY clause,
+        existing files are removed only from the partition folders to receive new files generated
+        by the UNLOAD operation. You must have the s3:DeleteObject permission on the Amazon S3 bucket.
+        Files removed using CLEANPATH are permanently deleted and can't be recovered.
+
+        For more information, see: https://docs.aws.amazon.com/redshift/latest/dg/r_UNLOAD.html
     boto3_session
         The default boto3 session will be used if **boto3_session** is ``None``.

@@ -307,6 +323,15 @@ def unload_to_files(
     ...         con=con,
     ...         iam_role="arn:aws:iam::XXX:role/XXX"
     ...     )
+    >>> # Using CLEANPATH instead of ALLOWOVERWRITE
+    >>> with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
+    ...     wr.redshift.unload_to_files(
+    ...         sql="SELECT * FROM public.mytable",
+    ...         path="s3://bucket/extracted_parquet_files/",
+    ...         con=con,
+    ...         iam_role="arn:aws:iam::XXX:role/XXX",
+    ...         cleanpath=True
+    ...     )

     """

@@ -339,11 +364,13 @@ def unload_to_files(
     # Escape quotation marks in SQL
     sql = sql.replace("'", "''")

+    overwrite_str: str = "CLEANPATH" if cleanpath else "ALLOWOVERWRITE"
+
     unload_sql = (
         f"UNLOAD ('{sql}')\n"
         f"TO '{path}'\n"
         f"{auth_str}"
-        "ALLOWOVERWRITE\n"
+        f"{overwrite_str}\n"
         f"{parallel_str}\n"
         f"FORMAT {format_str}\n"
         "ENCRYPTED"
@@ -376,6 +403,7 @@ def unload(
     chunked: bool | int = False,
     keep_files: bool = False,
     parallel: bool = True,
+    cleanpath: bool = False,
     use_threads: bool | int = True,
     boto3_session: boto3.Session | None = None,
     s3_additional_kwargs: dict[str, str] | None = None,
@@ -452,6 +480,21 @@ def unload(
         By default, UNLOAD writes data in parallel to multiple files, according to the number of slices in the
         cluster. If parallel is False, UNLOAD writes to one or more data files serially, sorted absolutely
         according to the ORDER BY clause, if one is used.
+    cleanpath
+        Use CLEANPATH instead of ALLOWOVERWRITE. When True, uses CLEANPATH to remove existing files
+        located in the Amazon S3 path before unloading files. When False (default), uses ALLOWOVERWRITE
+        to overwrite existing files, including the manifest file. These options are mutually exclusive.
+
+        ALLOWOVERWRITE: By default, UNLOAD fails if it finds files that it would possibly overwrite.
+        If ALLOWOVERWRITE is specified, UNLOAD overwrites existing files, including the manifest file.
+
+        CLEANPATH: Removes existing files located in the Amazon S3 path specified in the TO clause
+        before unloading files to the specified location. If you include the PARTITION BY clause,
+        existing files are removed only from the partition folders to receive new files generated
+        by the UNLOAD operation. You must have the s3:DeleteObject permission on the Amazon S3 bucket.
+        Files removed using CLEANPATH are permanently deleted and can't be recovered.
+
+        For more information, see: https://docs.aws.amazon.com/redshift/latest/dg/r_UNLOAD.html
     dtype_backend
         Which dtype_backend to use, e.g. whether a DataFrame should have NumPy arrays,
         nullable dtypes are used for all dtypes that have a nullable implementation when
@@ -489,6 +532,15 @@ def unload(
     ...         con=con,
     ...         iam_role="arn:aws:iam::XXX:role/XXX"
     ...     )
+    >>> # Using CLEANPATH instead of ALLOWOVERWRITE
+    >>> with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
+    ...     df = wr.redshift.unload(
+    ...         sql="SELECT * FROM public.mytable",
+    ...         path="s3://bucket/extracted_parquet_files/",
+    ...         con=con,
+    ...         iam_role="arn:aws:iam::XXX:role/XXX",
+    ...         cleanpath=True
+    ...     )

     """
     path = path if path.endswith("/") else f"{path}/"
@@ -505,6 +557,7 @@ def unload(
         kms_key_id=kms_key_id,
         manifest=False,
         parallel=parallel,
+        cleanpath=cleanpath,
         boto3_session=boto3_session,
     )
     if chunked is False:
diff --git a/tests/unit/test_redshift.py b/tests/unit/test_redshift.py
index 6dfb75112..d387c5973 100644
--- a/tests/unit/test_redshift.py
+++ b/tests/unit/test_redshift.py
@@ -1428,6 +1428,31 @@ def test_unload_escape_quotation_marks(
     assert len(df2) == 1


+@pytest.mark.parametrize("cleanpath", [False, True])
+def test_unload_cleanpath(
+    path: str,
+    redshift_table: str,
+    redshift_con: redshift_connector.Connection,
+    databases_parameters: dict[str, Any],
+    cleanpath: bool,
+) -> None:
+    df = pd.DataFrame({"id": [1, 2], "name": ["foo", "bar"]})
+    schema = "public"
+
+    wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema=schema, mode="overwrite", index=False)
+
+    df2 = wr.redshift.unload(
+        sql=f"SELECT * FROM {schema}.{redshift_table}",
+        con=redshift_con,
+        iam_role=databases_parameters["redshift"]["role"],
+        path=path,
+        keep_files=False,
+        cleanpath=cleanpath,
+    )
+    assert len(df2.index) == 2
+    assert len(df2.columns) == 2
+
+
 @pytest.mark.parametrize(
     "mode,overwrite_method",
     [