diff --git a/awswrangler/athena/_executions.py b/awswrangler/athena/_executions.py
index b2d3f518a..f3297771e 100644
--- a/awswrangler/athena/_executions.py
+++ b/awswrangler/athena/_executions.py
@@ -40,6 +40,7 @@ def start_query_execution(
     kms_key: str | None = None,
     params: dict[str, Any] | list[str] | None = None,
     paramstyle: Literal["qmark", "named"] = "named",
+    result_reuse_configuration: dict[str, Any] | None = None,
     boto3_session: boto3.Session | None = None,
     client_request_token: str | None = None,
     athena_cache_settings: typing.AthenaCacheSettings | None = None,
@@ -87,6 +88,9 @@ def start_query_execution(
         - ``named``
         - ``qmark``
 
+    result_reuse_configuration
+        A structure that contains the configuration settings for reusing query results.
+        See also: https://docs.aws.amazon.com/athena/latest/ug/reusing-query-results.html
     boto3_session
         The default boto3 session will be used if **boto3_session** receive ``None``.
     client_request_token
@@ -156,6 +160,7 @@ def start_query_execution(
         encryption=encryption,
         kms_key=kms_key,
         execution_params=execution_params,
+        result_reuse_configuration=result_reuse_configuration,
         client_request_token=client_request_token,
         boto3_session=boto3_session,
     )
diff --git a/awswrangler/athena/_executions.pyi b/awswrangler/athena/_executions.pyi
index 5a394d916..073e63c8d 100644
--- a/awswrangler/athena/_executions.pyi
+++ b/awswrangler/athena/_executions.pyi
@@ -18,6 +18,7 @@ def start_query_execution(
     kms_key: str | None = ...,
     params: dict[str, Any] | list[str] | None = ...,
     paramstyle: Literal["qmark", "named"] = ...,
+    result_reuse_configuration: dict[str, Any] | None = ...,
     boto3_session: boto3.Session | None = ...,
     athena_cache_settings: typing.AthenaCacheSettings | None = ...,
     athena_query_wait_polling_delay: float = ...,
@@ -35,6 +36,7 @@ def start_query_execution(
     kms_key: str | None = ...,
     params: dict[str, Any] | list[str] | None = ...,
     paramstyle: Literal["qmark", "named"] = ...,
+    result_reuse_configuration: dict[str, Any] | None = ...,
     boto3_session: boto3.Session | None = ...,
     athena_cache_settings: typing.AthenaCacheSettings | None = ...,
     athena_query_wait_polling_delay: float = ...,
@@ -52,6 +54,7 @@ def start_query_execution(
     kms_key: str | None = ...,
     params: dict[str, Any] | list[str] | None = ...,
     paramstyle: Literal["qmark", "named"] = ...,
+    result_reuse_configuration: dict[str, Any] | None = ...,
     boto3_session: boto3.Session | None = ...,
     athena_cache_settings: typing.AthenaCacheSettings | None = ...,
     athena_query_wait_polling_delay: float = ...,
diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py
index 34b088a1d..6d05fa35d 100644
--- a/awswrangler/athena/_read.py
+++ b/awswrangler/athena/_read.py
@@ -427,6 +427,7 @@ def _resolve_query_without_cache_regular(
     s3_additional_kwargs: dict[str, Any] | None,
     boto3_session: boto3.Session | None,
     execution_params: list[str] | None = None,
+    result_reuse_configuration: dict[str, Any] | None = None,
     dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
     client_request_token: str | None = None,
 ) -> pd.DataFrame | Iterator[pd.DataFrame]:
@@ -444,6 +445,7 @@ def _resolve_query_without_cache_regular(
         encryption=encryption,
         kms_key=kms_key,
         execution_params=execution_params,
+        result_reuse_configuration=result_reuse_configuration,
         client_request_token=client_request_token,
         boto3_session=boto3_session,
     )
@@ -467,7 +469,7 @@
     )
 
 
-def _resolve_query_without_cache(
+def _resolve_query_without_cache(  # noqa: PLR0913
     sql: str,
     database: str,
     data_source: str | None,
@@ -491,6 +493,7 @@ def _resolve_query_without_cache(
     boto3_session: boto3.Session | None,
     pyarrow_additional_kwargs: dict[str, Any] | None = None,
     execution_params: list[str] | None = None,
+    result_reuse_configuration: dict[str, Any] | None = None,
     dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
     client_request_token: str | None = None,
 ) -> pd.DataFrame | Iterator[pd.DataFrame]:
@@ -572,6 +575,7 @@ def _resolve_query_without_cache(
         s3_additional_kwargs=s3_additional_kwargs,
         boto3_session=boto3_session,
         execution_params=execution_params,
+        result_reuse_configuration=result_reuse_configuration,
         dtype_backend=dtype_backend,
         client_request_token=client_request_token,
     )
@@ -785,6 +789,7 @@ def read_sql_query(
     athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
     params: dict[str, Any] | list[str] | None = None,
     paramstyle: Literal["qmark", "named"] = "named",
+    result_reuse_configuration: dict[str, Any] | None = None,
     dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
     s3_additional_kwargs: dict[str, Any] | None = None,
     pyarrow_additional_kwargs: dict[str, Any] | None = None,
@@ -980,6 +985,10 @@ def read_sql_query(
         - ``named``
         - ``qmark``
 
+    result_reuse_configuration
+        A structure that contains the configuration settings for reusing query results.
+        This parameter is only valid when both `ctas_approach` and `unload_approach` are set to `False`.
+        See also: https://docs.aws.amazon.com/athena/latest/ug/reusing-query-results.html
     dtype_backend
         Which dtype_backend to use, e.g. whether a DataFrame should have NumPy arrays,
         nullable dtypes are used for all dtypes that have a nullable implementation when
@@ -1040,6 +1049,10 @@ def read_sql_query(
         raise exceptions.InvalidArgumentCombination(
             "Using `client_request_token` is only allowed when `ctas_approach=False` and `unload_approach=False`."
         )
+    if result_reuse_configuration and (ctas_approach or unload_approach):
+        raise exceptions.InvalidArgumentCombination(
+            "Using `result_reuse_configuration` is only allowed when `ctas_approach=False` and `unload_approach=False`."
+        )
     chunksize = sys.maxsize if ctas_approach is False and chunksize is True else chunksize
 
     # Substitute query parameters if applicable
@@ -1104,6 +1117,7 @@ def read_sql_query(
         boto3_session=boto3_session,
         pyarrow_additional_kwargs=pyarrow_additional_kwargs,
         execution_params=execution_params,
+        result_reuse_configuration=result_reuse_configuration,
         dtype_backend=dtype_backend,
         client_request_token=client_request_token,
     )
diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py
index 7a0e74aac..7ff7589cc 100644
--- a/awswrangler/athena/_utils.py
+++ b/awswrangler/athena/_utils.py
@@ -86,6 +86,7 @@ def _start_query_execution(
     encryption: str | None = None,
     kms_key: str | None = None,
     execution_params: list[str] | None = None,
+    result_reuse_configuration: dict[str, Any] | None = None,
     client_request_token: str | None = None,
     boto3_session: boto3.Session | None = None,
 ) -> str:
@@ -123,6 +124,9 @@ def _start_query_execution(
     if execution_params:
         args["ExecutionParameters"] = execution_params
 
+    if result_reuse_configuration:
+        args["ResultReuseConfiguration"] = result_reuse_configuration
+
     client_athena = _utils.client(service_name="athena", session=boto3_session)
     _logger.debug("Starting query execution with args: \n%s", pprint.pformat(args))
     response = _utils.try_it(
diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py
index d747ae001..47e444179 100644
--- a/tests/unit/test_athena.py
+++ b/tests/unit/test_athena.py
@@ -32,6 +32,86 @@
 pytestmark = pytest.mark.distributed
 
 
+def test_start_query_execution_with_result_reuse_configuration(path, glue_database, glue_table):
+    df = pd.DataFrame({"c0": [0, 1], "c1": ["foo", "bar"]})
+    wr.s3.to_parquet(
+        df=df,
+        path=path,
+        dataset=True,
+        database=glue_database,
+        table=glue_table,
+        mode="overwrite",
+    )
+
+    sql = f"select * from {glue_table}"
+    result_reuse_configuration = {"ResultReuseByAgeConfiguration": {"Enabled": True, "MaxAgeInMinutes": 1}}
+    query_execution_result1 = wr.athena.start_query_execution(
+        sql=sql, database=glue_database, result_reuse_configuration=result_reuse_configuration, wait=True
+    )
+    assert query_execution_result1["Query"] == sql
+    assert query_execution_result1["ResultReuseConfiguration"] == result_reuse_configuration
+    assert not query_execution_result1["Statistics"]["ResultReuseInformation"]["ReusedPreviousResult"]
+
+    query_execution_result2 = wr.athena.start_query_execution(
+        sql=sql, database=glue_database, result_reuse_configuration=result_reuse_configuration, wait=True
+    )
+    assert query_execution_result2["Query"] == sql
+    assert query_execution_result2["ResultReuseConfiguration"] == result_reuse_configuration
+    assert query_execution_result2["Statistics"]["ResultReuseInformation"]["ReusedPreviousResult"]
+
+
+def test_read_sql_query_with_result_reuse_configuration(path, glue_database, glue_table):
+    df = pd.DataFrame({"c0": [0, 1], "c1": ["foo", "bar"]})
+    wr.s3.to_parquet(
+        df=df,
+        path=path,
+        dataset=True,
+        database=glue_database,
+        table=glue_table,
+        mode="overwrite",
+    )
+
+    sql = f"select * from {glue_table}"
+    result_reuse_configuration = {"ResultReuseByAgeConfiguration": {"Enabled": True, "MaxAgeInMinutes": 1}}
+    df1 = wr.athena.read_sql_query(
+        sql=sql,
+        database=glue_database,
+        ctas_approach=False,
+        unload_approach=False,
+        result_reuse_configuration=result_reuse_configuration,
+    )
+    df2 = wr.athena.read_sql_query(
+        sql=sql,
+        database=glue_database,
+        ctas_approach=False,
+        unload_approach=False,
+        result_reuse_configuration=result_reuse_configuration,
+    )
+    assert pandas_equals(df1, df2)
+    assert not df1.query_metadata["Statistics"]["ResultReuseInformation"]["ReusedPreviousResult"]
+    assert df2.query_metadata["Statistics"]["ResultReuseInformation"]["ReusedPreviousResult"]
+
+
+def test_read_sql_query_with_result_reuse_configuration_error(glue_database):
+    # default behavior: ctas_approach is True and unload_approach is False
+    with pytest.raises(wr.exceptions.InvalidArgumentCombination):
+        wr.athena.read_sql_query(
+            sql="select 1",
+            database=glue_database,
+            result_reuse_configuration={"ResultReuseByAgeConfiguration": {"Enabled": True, "MaxAgeInMinutes": 1}},
+        )
+
+    # ctas_approach is False and unload_approach is True
+    with pytest.raises(wr.exceptions.InvalidArgumentCombination):
+        wr.athena.read_sql_query(
+            sql="select 1",
+            database=glue_database,
+            ctas_approach=False,
+            unload_approach=True,
+            result_reuse_configuration={"ResultReuseByAgeConfiguration": {"Enabled": True, "MaxAgeInMinutes": 1}},
+        )
+
+
 def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database, glue_ctas_database, kms_key):
     df = get_df_list()
     columns_types, partitions_types = wr.catalog.extract_athena_types(df=df, partition_cols=["par0", "par1"])
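
Usage sketch (not part of the patch): a minimal example of the new parameter, based on the calls exercised in the tests above. The database and table names are placeholders; the dict mirrors the ResultReuseConfiguration structure accepted by Athena's StartQueryExecution API, and for read_sql_query the parameter is only accepted when both ctas_approach and unload_approach are False.

import awswrangler as wr

# Ask Athena to reuse the result of an identical query for up to 1 minute.
result_reuse_configuration = {"ResultReuseByAgeConfiguration": {"Enabled": True, "MaxAgeInMinutes": 1}}

# start_query_execution forwards the structure to Athena as-is; with wait=True the
# returned execution details report whether a previous result was reused.
query_execution = wr.athena.start_query_execution(
    sql="SELECT * FROM my_table",  # placeholder table
    database="my_database",  # placeholder database
    result_reuse_configuration=result_reuse_configuration,
    wait=True,
)
print(query_execution["Statistics"]["ResultReuseInformation"]["ReusedPreviousResult"])

# read_sql_query accepts the parameter only with ctas_approach=False and unload_approach=False;
# any other combination raises wr.exceptions.InvalidArgumentCombination (see the check added above).
df = wr.athena.read_sql_query(
    sql="SELECT * FROM my_table",
    database="my_database",
    ctas_approach=False,
    unload_approach=False,
    result_reuse_configuration=result_reuse_configuration,
)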