Skip to content

Commit b6e4d88

Browse files
committed
feat(athena): improve start_query_executions with simplified tokens and parallel wait
- Simplified client_request_token handling:
  - Removed manual padding/truncation.
  - Let Athena enforce length constraints.
  - Tokens are generated as `<base_token>-<index>` or provided as a list.
- Improved wait logic:
  - Added optional wait handling directly inside _submit.
  - Queries can now be waited on in parallel with submission (reduced overhead).
- Configurable default threads:
  - Replaced hardcoded defaults with os.cpu_count().
  - Added support for an AWSWRANGLER_THREADS_DEFAULT env var override.
1 parent 17cd753 commit b6e4d88

File tree

1 file changed

+52
-87
lines changed

1 file changed

+52
-87
lines changed

awswrangler/athena/_executions.py

Lines changed: 52 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
_logger: logging.Logger = logging.getLogger(__name__)
3434

35+
_DEFAULT_MAX_WORKERS = max(4, os.cpu_count() or 4)
3536

3637
@apply_configs
3738
def start_query_execution(
@@ -179,25 +180,25 @@ def start_query_executions(
179180
workgroup: str = "primary",
180181
encryption: str | None = None,
181182
kms_key: str | None = None,
182-
params: dict[str, Any] | list[str] | None = None,
183+
params: dict[str, typing.Any] | list[str] | None = None,
183184
paramstyle: Literal["qmark", "named"] = "named",
184185
boto3_session: boto3.Session | None = None,
185-
client_request_token: str | None = None,
186+
client_request_token: str | list[str] | None = None,
186187
athena_cache_settings: typing.AthenaCacheSettings | None = None,
187-
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
188+
athena_query_wait_polling_delay: float = 1.0,
188189
data_source: str | None = None,
189190
wait: bool = False,
190191
check_workgroup: bool = True,
191192
enforce_workgroup: bool = False,
192193
as_iterator: bool = False,
193-
use_threads: bool | int = False
194-
) -> list[str] | list[dict[str, Any]]:
194+
use_threads: bool | int = False,
195+
) -> list[str] | list[dict[str, typing.Any]]:
195196
"""
196197
Start multiple SQL queries against Amazon Athena.
197198
198-
This function is the multi-query variant of ``start_query_execution``.
199-
It supports caching, idempotent request tokens, workgroup configuration,
200-
sequential or parallel execution, and lazy or eager iteration.
199+
Each query can optionally use Athena's result cache and idempotent request tokens.
200+
Submissions can be sequential or parallel, and each query can be waited on
201+
individually (inside its submission thread) if ``wait=True``.
201202
202203
Parameters
203204
----------
@@ -216,91 +217,51 @@ def start_query_executions(
216217
params : dict or list, optional
217218
Query parameters. Behavior depends on ``paramstyle``.
218219
paramstyle : {'named', 'qmark'}, default 'named'
219-
Parameter substitution style:
220-
- 'named': ``{"name": "value"}`` and query must use ``:name``.
221-
- 'qmark': list of values, substituted sequentially.
220+
Parameter substitution style.
222221
boto3_session : boto3.Session, optional
223222
Existing boto3 session. A new session will be created if None.
224223
client_request_token : str | list[str], optional
225-
Idempotency token(s) for Athena:
226-
- If a string: suffixed with an index to generate unique tokens.
227-
- If a list: must have same length as ``sqls``.
228-
- If None: no token provided (duplicate submissions possible).
229-
Tokens are padded/truncated to comply with Athena’s requirement (32–128 chars).
224+
Idempotency token(s). If a string, suffixed with query index.
230225
athena_cache_settings : dict, optional
231-
Wrangler cache settings to reuse results when possible.
226+
Wrangler cache settings for query result reuse.
232227
athena_query_wait_polling_delay : float, default 1.0
233-
Interval in seconds between query status checks when waiting.
228+
Interval between status checks when waiting for queries.
234229
data_source : str, optional
235230
Data catalog name (default 'AwsDataCatalog').
236231
wait : bool, default False
237-
If True, block until queries complete and return their execution details.
238-
If False, return query IDs immediately.
232+
If True, block until each query completes.
239233
check_workgroup : bool, default True
240-
If True, call GetWorkGroup once to retrieve workgroup configuration.
241-
If False, build a workgroup config from provided parameters (faster, fewer API calls).
234+
If True, fetch workgroup config from Athena.
242235
enforce_workgroup : bool, default False
243-
If True, mark the dummy workgroup config as "enforced" when skipping GetWorkGroup.
236+
If True, enforce workgroup config even when skipping fetch.
244237
as_iterator : bool, default False
245-
If True, return a lazy iterator instead of a list.
238+
If True, return an iterator instead of a list.
246239
use_threads : bool | int, default False
247-
Controls parallelism:
248-
- False: submit queries sequentially.
249-
- True: use ``os.cpu_count()`` worker threads.
250-
- int: number of worker threads to use.
240+
Parallelism:
241+
- False: sequential execution
242+
- True: ``os.cpu_count()`` threads
243+
- int: number of worker threads
251244
252245
Returns
253246
-------
254-
list[str] | list[dict[str, Any]] | Iterator
255-
- If ``wait=False``: list or iterator of query execution IDs.
256-
- If ``wait=True``: list or iterator of query execution metadata dicts.
257-
258-
Examples
259-
--------
260-
Sequential, no wait:
261-
>>> qids = wr.athena.start_query_executions(
262-
... sqls=["SELECT 1", "SELECT 2"],
263-
... database="default",
264-
... s3_output="s3://my-bucket/results/",
265-
... )
266-
>>> print(list(qids))
267-
['abc-123...', 'def-456...']
268-
269-
Parallel execution with 8 threads:
270-
>>> qids = wr.athena.start_query_executions(
271-
... sqls=["SELECT 1", "SELECT 2", "SELECT 3"],
272-
... database="default",
273-
... s3_output="s3://my-bucket/results/",
274-
... use_threads=8,
275-
... )
276-
277-
Waiting for completion and retrieving metadata:
278-
>>> results = wr.athena.start_query_executions(
279-
... sqls=["SELECT 1"],
280-
... database="default",
281-
... s3_output="s3://my-bucket/results/",
282-
... wait=True
283-
... )
284-
>>> print(results[0]["Status"]["State"])
285-
'SUCCEEDED'
247+
list[str] | list[dict] | Iterator
248+
QueryExecutionIds or execution metadata dicts if ``wait=True``.
286249
"""
287-
288250
session = boto3_session or boto3.Session()
289-
client = session.client("athena")
290251

291252
if isinstance(client_request_token, list):
292253
if len(client_request_token) != len(sqls):
293254
raise ValueError("Length of client_request_token list must match number of queries in sqls")
294255
tokens = client_request_token
295256
elif isinstance(client_request_token, str):
296-
tokens = [f"{client_request_token}-{i}".ljust(32, "x")[:128] for i in range(len(sqls))]
257+
tokens = [f"{client_request_token}-{i}" for i in range(len(sqls))]
297258
else:
298259
tokens = [None] * len(sqls)
299260

300261
formatted_queries = list(map(lambda q: _apply_formatter(q, params, paramstyle), sqls))
301262

302263
if check_workgroup:
303-
wg_config: _WorkGroupConfig = _utils._get_workgroup_config(session=session, workgroup=workgroup)
264+
wg_config: _WorkGroupConfig = _get_workgroup_config(session=session, workgroup=workgroup)
304265
else:
305266
wg_config = _WorkGroupConfig(
306267
enforced=enforce_workgroup,
@@ -309,20 +270,28 @@ def start_query_executions(
309270
kms_key=kms_key,
310271
)
311272

312-
def _submit(item):
273+
def _submit(item: tuple[tuple[str, list[str] | None], str | None]):
313274
(q, execution_params), token = item
314275

315276
if token is None and athena_cache_settings is not None:
316-
cache_info = _executions._check_for_cached_results(
277+
cache_info = _check_for_cached_results(
317278
sql=q,
318279
boto3_session=session,
319280
workgroup=workgroup,
320281
athena_cache_settings=athena_cache_settings,
321282
)
322283
if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
323-
return cache_info.query_execution_id
324-
325-
return _start_query_execution(
284+
return (
285+
wait_query(
286+
query_execution_id=cache_info.query_execution_id,
287+
boto3_session=session,
288+
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
289+
)
290+
if wait
291+
else cache_info.query_execution_id
292+
)
293+
294+
qid = _start_query_execution(
326295
sql=q,
327296
wg_config=wg_config,
328297
database=database,
@@ -336,29 +305,25 @@ def _submit(item):
336305
boto3_session=session,
337306
)
338307

308+
if wait:
309+
return wait_query(
310+
query_execution_id=qid,
311+
boto3_session=session,
312+
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
313+
)
314+
315+
return qid
316+
339317
items = list(zip(formatted_queries, tokens))
340318

341319
if use_threads is False:
342-
query_ids = map(_submit, items)
320+
results = map(_submit, items)
343321
else:
344-
max_workers = (
345-
os.cpu_count() or 4 if use_threads is True else int(use_threads)
346-
)
347-
executor = ThreadPoolExecutor(max_workers=max_workers)
348-
query_ids = executor.map(_submit, items)
349-
350-
if wait:
351-
results_iter = map(
352-
lambda qid: wait_query(
353-
query_execution_id=qid,
354-
boto3_session=session,
355-
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
356-
),
357-
query_ids,
358-
)
359-
return results_iter if as_iterator else list(results_iter)
322+
max_workers = _DEFAULT_MAX_WORKERS if use_threads is True else int(use_threads)
323+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
324+
results = executor.map(_submit, items)
360325

361-
return query_ids if as_iterator else list(query_ids)
326+
return results if as_iterator else list(results)
362327

363328

364329
def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None:

0 commit comments

Comments
 (0)