Skip to content

Commit 89449cb

Browse files
committed
feat(athena): support named & qmark parameters; use generators; update docstring
1 parent b8a607f commit 89449cb

File tree

2 files changed

+42
-17
lines changed

2 files changed

+42
-17
lines changed

awswrangler/athena/_executions.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -184,22 +184,24 @@ def start_query_executions(
184184
params: dict[str, typing.Any] | list[str] | None = None,
185185
paramstyle: Literal["qmark", "named"] = "named",
186186
boto3_session: boto3.Session | None = None,
187-
client_request_token: str | list[str] | None = None,
187+
client_request_token: str | list[list[str]] | None = None,
188188
athena_cache_settings: typing.AthenaCacheSettings | None = None,
189-
athena_query_wait_polling_delay: float = 1.0,
189+
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
190190
data_source: str | None = None,
191191
wait: bool = False,
192192
check_workgroup: bool = True,
193193
enforce_workgroup: bool = False,
194194
as_iterator: bool = False,
195-
use_threads: bool | int = False,
195+
use_threads: bool | int = False
196196
) -> list[str] | list[dict[str, typing.Any]]:
197197
"""
198198
Start multiple SQL queries against Amazon Athena.
199199
200-
Each query can optionally use Athena's result cache and idempotent request tokens.
201-
Submissions can be sequential or parallel, and each query can be waited on
202-
individually (inside its submission thread) if ``wait=True``.
200+
This is the multi-query counterpart to ``start_query_execution``. It supports
201+
per-query caching and idempotent client request tokens, optional workgroup
202+
validation/enforcement, sequential or thread-pooled parallel dispatch, and
203+
either eager (list) or lazy (iterator) consumption. If ``wait=True``, each
204+
query may be awaited to completion within its submission thread.
203205
204206
Parameters
205207
----------
@@ -255,11 +257,20 @@ def start_query_executions(
255257
raise ValueError("Length of client_request_token list must match number of queries in sqls")
256258
tokens = client_request_token
257259
elif isinstance(client_request_token, str):
258-
tokens = [f"{client_request_token}-{i}" for i in range(len(sqls))]
260+
tokens = (f"{client_request_token}-{i}" for i in range(len(sqls)))
259261
else:
260262
tokens = [None] * len(sqls)
261263

262-
formatted_queries = list(map(lambda q: _apply_formatter(q, params, paramstyle), sqls))
264+
if paramstyle == "named":
265+
formatted_queries = (_apply_formatter(q, params, "named") for q in sqls)
266+
elif paramstyle == "qmark":
267+
_params_list = params or [None] * len(sqls)
268+
formatted_queries = (
269+
_apply_formatter(q, query_params, "qmark")
270+
for q, query_params in zip(sqls, _params_list)
271+
)
272+
else:
273+
raise ValueError("paramstyle must be 'named' or 'qmark'")
263274

264275
if check_workgroup:
265276
wg_config: _WorkGroupConfig = _get_workgroup_config(session=session, workgroup=workgroup)
@@ -273,6 +284,7 @@ def start_query_executions(
273284

274285
def _submit(item: tuple[tuple[str, list[str] | None], str | None]):
275286
(q, execution_params), token = item
287+
_logger.debug("Executing query:\n%s", q)
276288

277289
if token is None and athena_cache_settings is not None:
278290
cache_info = _check_for_cached_results(
@@ -281,7 +293,9 @@ def _submit(item: tuple[tuple[str, list[str] | None], str | None]):
281293
workgroup=workgroup,
282294
athena_cache_settings=athena_cache_settings,
283295
)
296+
_logger.debug("Cache info:\n%s", cache_info)
284297
if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
298+
_logger.debug("Valid cache found. Retrieving...")
285299
return (
286300
wait_query(
287301
query_execution_id=cache_info.query_execution_id,
@@ -315,17 +329,28 @@ def _submit(item: tuple[tuple[str, list[str] | None], str | None]):
315329

316330
return qid
317331

318-
items = list(zip(formatted_queries, tokens))
332+
items = zip(formatted_queries, tokens)
319333

320334
if use_threads is False:
321335
results = map(_submit, items)
322-
else:
323-
max_workers = _DEFAULT_MAX_WORKERS if use_threads is True else int(use_threads)
324-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
325-
results = executor.map(_submit, items)
336+
return results if as_iterator else list(results)
326337

327-
return results if as_iterator else list(results)
338+
max_workers = _DEFAULT_MAX_WORKERS if use_threads is True else int(use_threads)
328339

340+
if as_iterator:
341+
executor = ThreadPoolExecutor(max_workers=max_workers)
342+
it = executor.map(_submit, items)
343+
344+
def _iter():
345+
try:
346+
yield from it
347+
finally:
348+
executor.shutdown(wait=True)
349+
350+
return _iter()
351+
else:
352+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
353+
return list(executor.map(_submit, items))
329354

330355
def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None:
331356
"""Stop a query execution.

awswrangler/athena/_executions.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def start_query_executions(
6969
params: dict[str, Any] | list[str] | None = ...,
7070
paramstyle: Literal["qmark", "named"] = ...,
7171
boto3_session: boto3.Session | None = ...,
72-
client_request_token: str | list[str] | None = ...,
72+
client_request_token: str | list[list[str]] | None = ...,
7373
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
7474
athena_query_wait_polling_delay: float = ...,
7575
data_source: str | None = ...,
@@ -91,7 +91,7 @@ def start_query_executions(
9191
params: dict[str, Any] | list[str] | None = ...,
9292
paramstyle: Literal["qmark", "named"] = ...,
9393
boto3_session: boto3.Session | None = ...,
94-
client_request_token: str | list[str] | None = ...,
94+
client_request_token: str | list[list[str]] | None = ...,
9595
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
9696
athena_query_wait_polling_delay: float = ...,
9797
data_source: str | None = ...,
@@ -113,7 +113,7 @@ def start_query_executions(
113113
params: dict[str, Any] | list[str] | None = ...,
114114
paramstyle: Literal["qmark", "named"] = ...,
115115
boto3_session: boto3.Session | None = ...,
116-
client_request_token: str | list[str] | None = ...,
116+
client_request_token: str | list[list[str]] | None = ...,
117117
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
118118
athena_query_wait_polling_delay: float = ...,
119119
data_source: str | None = ...,

0 commit comments

Comments
 (0)