Skip to content

Commit f7d8b93

Browse files
authored
feat: Athena - add client_request_token (#2474)
* feat: Athena - add client_request_token Signed-off-by: Anton Kukushkin <[email protected]> * refactor: Remove cache defaults copy-paste Signed-off-by: Anton Kukushkin <[email protected]> * test: Add client_request_token test case Signed-off-by: Anton Kukushkin <[email protected]> * PR feedback Signed-off-by: Anton Kukushkin <[email protected]> * Fix local cache test case Signed-off-by: Anton Kukushkin <[email protected]> * Drop coverage limit slightly Signed-off-by: Anton Kukushkin <[email protected]> --------- Signed-off-by: Anton Kukushkin <[email protected]>
1 parent 69625a9 commit f7d8b93

File tree

6 files changed

+124
-63
lines changed

6 files changed

+124
-63
lines changed

awswrangler/athena/_cache.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import boto3
1010

11-
from awswrangler import _utils
11+
from awswrangler import _utils, typing
1212

1313
if TYPE_CHECKING:
1414
from mypy_boto3_athena.type_defs import QueryExecutionTypeDef
@@ -170,15 +170,23 @@ def _check_for_cached_results(
170170
sql: str,
171171
boto3_session: Optional[boto3.Session],
172172
workgroup: Optional[str],
173-
max_cache_seconds: int,
174-
max_cache_query_inspections: int,
175-
max_remote_cache_entries: int,
173+
athena_cache_settings: Optional[typing.AthenaCacheSettings] = None,
176174
) -> _CacheInfo:
177175
"""
178176
Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.
179177
180178
If so, returns a dict with Athena's `query_execution_info` and the data format.
181179
"""
180+
athena_cache_settings = athena_cache_settings or {}
181+
182+
max_cache_seconds = athena_cache_settings.get("max_cache_seconds", 0)
183+
max_cache_query_inspections = athena_cache_settings.get("max_cache_query_inspections", 50)
184+
max_remote_cache_entries = athena_cache_settings.get("max_remote_cache_entries", 50)
185+
max_local_cache_entries = athena_cache_settings.get("max_local_cache_entries", 100)
186+
max_remote_cache_entries = min(max_remote_cache_entries, max_local_cache_entries)
187+
188+
_cache_manager.max_cache_size = max_local_cache_entries
189+
182190
if max_cache_seconds <= 0:
183191
return _CacheInfo(has_valid_cache=False)
184192
num_executions_inspected: int = 0

awswrangler/athena/_executions.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from awswrangler import _utils, exceptions, typing
1818
from awswrangler._config import apply_configs
1919

20-
from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results
20+
from ._cache import _CacheInfo, _check_for_cached_results
2121
from ._utils import (
2222
_QUERY_FINAL_STATES,
2323
_QUERY_WAIT_POLLING_DELAY,
@@ -41,6 +41,7 @@ def start_query_execution(
4141
params: Union[Dict[str, Any], List[str], None] = None,
4242
paramstyle: Literal["qmark", "named"] = "named",
4343
boto3_session: Optional[boto3.Session] = None,
44+
client_request_token: Optional[str] = None,
4445
athena_cache_settings: Optional[typing.AthenaCacheSettings] = None,
4546
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
4647
data_source: Optional[str] = None,
@@ -88,6 +89,13 @@ def start_query_execution(
8889
- ``qmark``
8990
boto3_session : boto3.Session(), optional
9091
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
92+
client_request_token : str, optional
93+
A unique case-sensitive string used to ensure the request to create the query is idempotent (executes only once).
94+
If another StartQueryExecution request is received, the same response is returned and another query is not created.
95+
If a parameter has changed, for example, the QueryString , an error is returned.
96+
If you pass the same client_request_token value with different parameters the query fails with error
97+
message "Idempotent parameters do not match". Use this only with ctas_approach=False and unload_approach=False
98+
and disabled cache.
9199
athena_cache_settings: typing.AthenaCacheSettings, optional
92100
Parameters of the Athena cache settings such as max_cache_seconds, max_cache_query_inspections,
93101
max_remote_cache_entries, and max_local_cache_entries.
@@ -125,26 +133,16 @@ def start_query_execution(
125133
sql, execution_params = _apply_formatter(sql, params, paramstyle)
126134
_logger.debug("Executing query:\n%s", sql)
127135

128-
athena_cache_settings = athena_cache_settings if athena_cache_settings else {}
129-
max_cache_seconds = athena_cache_settings.get("max_cache_seconds", 0)
130-
max_cache_query_inspections = athena_cache_settings.get("max_cache_query_inspections", 50)
131-
max_remote_cache_entries = athena_cache_settings.get("max_remote_cache_entries", 50)
132-
max_local_cache_entries = athena_cache_settings.get("max_local_cache_entries", 100)
133-
134-
max_remote_cache_entries = min(max_remote_cache_entries, max_local_cache_entries)
135-
136-
_cache_manager.max_cache_size = max_local_cache_entries
137-
cache_info: _CacheInfo = _check_for_cached_results(
138-
sql=sql,
139-
boto3_session=boto3_session,
140-
workgroup=workgroup,
141-
max_cache_seconds=max_cache_seconds,
142-
max_cache_query_inspections=max_cache_query_inspections,
143-
max_remote_cache_entries=max_remote_cache_entries,
144-
)
145-
_logger.debug("Cache info:\n%s", cache_info)
136+
if not client_request_token:
137+
cache_info: _CacheInfo = _check_for_cached_results(
138+
sql=sql,
139+
boto3_session=boto3_session,
140+
workgroup=workgroup,
141+
athena_cache_settings=athena_cache_settings,
142+
)
143+
_logger.debug("Cache info:\n%s", cache_info)
146144

147-
if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
145+
if not client_request_token and cache_info.has_valid_cache and cache_info.query_execution_id is not None:
148146
query_execution_id = cache_info.query_execution_id
149147
_logger.debug("Valid cache found. Retrieving...")
150148
else:
@@ -159,6 +157,7 @@ def start_query_execution(
159157
encryption=encryption,
160158
kms_key=kms_key,
161159
execution_params=execution_params,
160+
client_request_token=client_request_token,
162161
boto3_session=boto3_session,
163162
)
164163
if wait:

awswrangler/athena/_read.py

Lines changed: 55 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ def _resolve_query_without_cache_regular(
427427
boto3_session: Optional[boto3.Session],
428428
execution_params: Optional[List[str]] = None,
429429
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
430+
client_request_token: Optional[str] = None,
430431
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
431432
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
432433
s3_output = _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=boto3_session)
@@ -442,6 +443,7 @@ def _resolve_query_without_cache_regular(
442443
encryption=encryption,
443444
kms_key=kms_key,
444445
execution_params=execution_params,
446+
client_request_token=client_request_token,
445447
boto3_session=boto3_session,
446448
)
447449
_logger.debug("Query id: %s", query_id)
@@ -490,6 +492,7 @@ def _resolve_query_without_cache(
490492
pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
491493
execution_params: Optional[List[str]] = None,
492494
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
495+
client_request_token: Optional[str] = None,
493496
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
494497
"""
495498
Execute a query in Athena and returns results as DataFrame, back to `read_sql_query`.
@@ -570,6 +573,7 @@ def _resolve_query_without_cache(
570573
boto3_session=boto3_session,
571574
execution_params=execution_params,
572575
dtype_backend=dtype_backend,
576+
client_request_token=client_request_token,
573577
)
574578

575579

@@ -776,6 +780,7 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
776780
keep_files: bool = True,
777781
use_threads: Union[bool, int] = True,
778782
boto3_session: Optional[boto3.Session] = None,
783+
client_request_token: Optional[str] = None,
779784
athena_cache_settings: Optional[typing.AthenaCacheSettings] = None,
780785
data_source: Optional[str] = None,
781786
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
@@ -938,6 +943,13 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
938943
If integer is provided, specified number is used.
939944
boto3_session : boto3.Session(), optional
940945
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
946+
client_request_token : str, optional
947+
A unique case-sensitive string used to ensure the request to create the query is idempotent (executes only once).
948+
If another StartQueryExecution request is received, the same response is returned and another query is not created.
949+
If a parameter has changed, for example, the QueryString , an error is returned.
950+
If you pass the same client_request_token value with different parameters the query fails with error
951+
message "Idempotent parameters do not match". Use this only with ctas_approach=False and unload_approach=False
952+
and disabled cache.
941953
athena_cache_settings: typing.AthenaCacheSettings, optional
942954
Parameters of the Athena cache settings such as max_cache_seconds, max_cache_query_inspections,
943955
max_remote_cache_entries, and max_local_cache_entries.
@@ -1022,46 +1034,44 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
10221034
raise exceptions.InvalidArgumentCombination("Only one of ctas_approach=True or unload_approach=True is allowed")
10231035
if unload_parameters and unload_parameters.get("file_format") not in (None, "PARQUET"):
10241036
raise exceptions.InvalidArgumentCombination("Only PARQUET file format is supported if unload_approach=True")
1037+
if client_request_token and athena_cache_settings:
1038+
raise exceptions.InvalidArgumentCombination(
1039+
"Only one of `client_request_token` or `athena_cache_settings` is allowed."
1040+
)
1041+
if client_request_token and (ctas_approach or unload_approach):
1042+
raise exceptions.InvalidArgumentCombination(
1043+
"Using `client_request_token` is only allowed when `ctas_approach=False` and `unload_approach=False`."
1044+
)
10251045
chunksize = sys.maxsize if ctas_approach is False and chunksize is True else chunksize
10261046

10271047
# Substitute query parameters if applicable
10281048
sql, execution_params = _apply_formatter(sql, params, paramstyle)
10291049

1030-
athena_cache_settings = athena_cache_settings if athena_cache_settings else {}
1031-
max_cache_seconds = athena_cache_settings.get("max_cache_seconds", 0)
1032-
max_cache_query_inspections = athena_cache_settings.get("max_cache_query_inspections", 50)
1033-
max_remote_cache_entries = athena_cache_settings.get("max_remote_cache_entries", 50)
1034-
max_local_cache_entries = athena_cache_settings.get("max_local_cache_entries", 100)
1035-
1036-
max_remote_cache_entries = min(max_remote_cache_entries, max_local_cache_entries)
1037-
1038-
_cache_manager.max_cache_size = max_local_cache_entries
1039-
cache_info: _CacheInfo = _check_for_cached_results(
1040-
sql=sql,
1041-
boto3_session=boto3_session,
1042-
workgroup=workgroup,
1043-
max_cache_seconds=max_cache_seconds,
1044-
max_cache_query_inspections=max_cache_query_inspections,
1045-
max_remote_cache_entries=max_remote_cache_entries,
1046-
)
1047-
_logger.debug("Cache info:\n%s", cache_info)
1048-
if cache_info.has_valid_cache is True:
1049-
_logger.debug("Valid cache found. Retrieving...")
1050-
try:
1051-
return _resolve_query_with_cache(
1052-
cache_info=cache_info,
1053-
categories=categories,
1054-
chunksize=chunksize,
1055-
use_threads=use_threads,
1056-
session=boto3_session,
1057-
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
1058-
s3_additional_kwargs=s3_additional_kwargs,
1059-
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
1060-
dtype_backend=dtype_backend,
1061-
)
1062-
except Exception as e: # pylint: disable=broad-except
1063-
_logger.error(e) # if there is anything wrong with the cache, just fallback to the usual path
1064-
_logger.debug("Corrupted cache. Continuing to execute query...")
1050+
if not client_request_token:
1051+
cache_info: _CacheInfo = _check_for_cached_results(
1052+
sql=sql,
1053+
boto3_session=boto3_session,
1054+
workgroup=workgroup,
1055+
athena_cache_settings=athena_cache_settings,
1056+
)
1057+
_logger.debug("Cache info:\n%s", cache_info)
1058+
if cache_info.has_valid_cache is True:
1059+
_logger.debug("Valid cache found. Retrieving...")
1060+
try:
1061+
return _resolve_query_with_cache(
1062+
cache_info=cache_info,
1063+
categories=categories,
1064+
chunksize=chunksize,
1065+
use_threads=use_threads,
1066+
session=boto3_session,
1067+
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
1068+
s3_additional_kwargs=s3_additional_kwargs,
1069+
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
1070+
dtype_backend=dtype_backend,
1071+
)
1072+
except Exception as e: # pylint: disable=broad-except
1073+
_logger.error(e) # if there is anything wrong with the cache, just fallback to the usual path
1074+
_logger.debug("Corrupted cache. Continuing to execute query...")
10651075

10661076
ctas_parameters = ctas_parameters if ctas_parameters else {}
10671077
ctas_database = ctas_parameters.get("database")
@@ -1094,6 +1104,7 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
10941104
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
10951105
execution_params=execution_params,
10961106
dtype_backend=dtype_backend,
1107+
client_request_token=client_request_token,
10971108
)
10981109

10991110

@@ -1117,6 +1128,7 @@ def read_sql_table(
11171128
keep_files: bool = True,
11181129
use_threads: Union[bool, int] = True,
11191130
boto3_session: Optional[boto3.Session] = None,
1131+
client_request_token: Optional[str] = None,
11201132
athena_cache_settings: Optional[typing.AthenaCacheSettings] = None,
11211133
data_source: Optional[str] = None,
11221134
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
@@ -1274,6 +1286,13 @@ def read_sql_table(
12741286
If integer is provided, specified number is used.
12751287
boto3_session : boto3.Session(), optional
12761288
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
1289+
client_request_token : str, optional
1290+
A unique case-sensitive string used to ensure the request to create the query is idempotent (executes only once).
1291+
If another StartQueryExecution request is received, the same response is returned and another query is not created.
1292+
If a parameter has changed, for example, the QueryString , an error is returned.
1293+
If you pass the same client_request_token value with different parameters the query fails with error
1294+
message "Idempotent parameters do not match". Use this only with ctas_approach=False and unload_approach=False
1295+
and disabled cache.
12771296
athena_cache_settings: typing.AthenaCacheSettings, optional
12781297
Parameters of the Athena cache settings such as max_cache_seconds, max_cache_query_inspections,
12791298
max_remote_cache_entries, and max_local_cache_entries.
@@ -1327,6 +1346,7 @@ def read_sql_table(
13271346
keep_files=keep_files,
13281347
use_threads=use_threads,
13291348
boto3_session=boto3_session,
1349+
client_request_token=client_request_token,
13301350
athena_cache_settings=athena_cache_settings,
13311351
data_source=data_source,
13321352
s3_additional_kwargs=s3_additional_kwargs,

awswrangler/athena/_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def _start_query_execution(
8686
encryption: Optional[str] = None,
8787
kms_key: Optional[str] = None,
8888
execution_params: Optional[List[str]] = None,
89+
client_request_token: Optional[str] = None,
8990
boto3_session: Optional[boto3.Session] = None,
9091
) -> str:
9192
args: Dict[str, Any] = {"QueryString": sql}
@@ -116,6 +117,9 @@ def _start_query_execution(
116117
if workgroup is not None:
117118
args["WorkGroup"] = workgroup
118119

120+
if client_request_token:
121+
args["ClientRequestToken"] = client_request_token
122+
119123
if execution_params:
120124
args["ExecutionParameters"] = execution_params
121125

tests/unit/test_athena_cache.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ def test_local_cache(wr, path, glue_database, glue_table):
175175
df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
176176
wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", database=glue_database, table=glue_table)
177177

178+
# Set max cache size because it is supposed to be set in the patched method below
179+
wr.athena._cache._cache_manager.max_cache_size = 1
180+
178181
with patch(
179182
"awswrangler.athena._read._check_for_cached_results",
180183
return_value=wr.athena._read._CacheInfo(has_valid_cache=False),
@@ -189,7 +192,7 @@ def test_local_cache(wr, path, glue_database, glue_table):
189192
assert df.shape == df2.shape
190193
assert df.c0.sum() == df2.c0.sum()
191194
first_query_id = df2.query_metadata["QueryExecutionId"]
192-
assert first_query_id in wr.athena._read._cache_manager
195+
assert first_query_id in wr.athena._cache._cache_manager
193196

194197
df3 = wr.athena.read_sql_query(
195198
f"SELECT * FROM {glue_table}",
@@ -202,8 +205,8 @@ def test_local_cache(wr, path, glue_database, glue_table):
202205
assert df.c0.sum() == df3.c0.sum()
203206
second_query_id = df3.query_metadata["QueryExecutionId"]
204207

205-
assert first_query_id not in wr.athena._read._cache_manager
206-
assert second_query_id in wr.athena._read._cache_manager
208+
assert first_query_id not in wr.athena._cache._cache_manager
209+
assert second_query_id in wr.athena._cache._cache_manager
207210

208211

209212
def test_paginated_remote_cache(wr, path, glue_database, glue_table, workgroup1):
@@ -260,3 +263,30 @@ def test_cache_start_query(wr, path, glue_database, glue_table, data_source):
260263
)
261264
internal_start_query.assert_not_called()
262265
assert query_id == query_id_2
266+
267+
268+
def test_start_query_client_request_token(wr, path, glue_database, glue_table):
269+
df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
270+
271+
wr.s3.to_parquet(
272+
df=df,
273+
path=path,
274+
dataset=True,
275+
mode="overwrite",
276+
database=glue_database,
277+
table=glue_table,
278+
)
279+
280+
client_request_token = f"token-{glue_database}-{glue_table}-1"
281+
query_id_1 = wr.athena.start_query_execution(
282+
sql=f"SELECT * FROM {glue_table}",
283+
database=glue_database,
284+
client_request_token=client_request_token,
285+
)
286+
query_id_2 = wr.athena.start_query_execution(
287+
sql=f"SELECT * FROM {glue_table}",
288+
database=glue_database,
289+
client_request_token=client_request_token,
290+
)
291+
292+
assert query_id_1 == query_id_2

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ passenv =
1010
AWS_SECRET_ACCESS_KEY
1111
AWS_SESSION_TOKEN
1212
setenv =
13-
COV_FAIL_UNDER = 87.00
13+
COV_FAIL_UNDER = 82.00
1414
allowlist_externals = poetry
1515
commands_pre =
1616
poetry install --no-root --sync --extras "deltalake gremlin mysql opencypher opensearch oracle postgres redshift sparql sqlserver geopandas"

0 commit comments

Comments
 (0)