Skip to content

Commit 17cd753

Browse files
committed
feat(athena): add start_query_executions for async multi-query execution
Introduce `wr.athena.start_query_executions` as a multi-query variant of `start_query_execution`. It allows submitting multiple queries in one call, with support for:
- Sequential or threaded submission (`use_threads`)
- Lazy or eager consumption of results (`as_iterator`)
- Per-query `client_request_token` (string or list)
- Optional workgroup checks (`check_workgroup`, `enforce_workgroup`)
- Full Athena cache integration

This improves performance when dispatching batches of queries by reducing workgroup lookups and enabling concurrent submission.
1 parent c66e652 commit 17cd753

File tree

2 files changed

+194
-0
lines changed

2 files changed

+194
-0
lines changed

awswrangler/athena/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
get_query_execution,
55
stop_query_execution,
66
start_query_execution,
7+
start_query_executions,
78
wait_query,
89
)
910
from awswrangler.athena._spark import create_spark_session, run_spark_calculation
@@ -53,6 +54,7 @@
5354
"create_ctas_table",
5455
"show_create_table",
5556
"start_query_execution",
57+
"start_query_executions",
5658
"stop_query_execution",
5759
"unload",
5860
"wait_query",

awswrangler/athena/_executions.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@
1010
cast,
1111
)
1212

13+
import os
1314
import boto3
1415
import botocore
1516
from typing_extensions import Literal
1617

18+
from concurrent.futures import ThreadPoolExecutor
1719
from awswrangler import _utils, exceptions, typing
1820
from awswrangler._config import apply_configs
21+
from functools import reduce
1922

2023
from ._cache import _CacheInfo, _check_for_cached_results
2124
from ._utils import (
@@ -168,6 +171,195 @@ def start_query_execution(
168171

169172
return query_execution_id
170173

174+
@apply_configs
def start_query_executions(
    sqls: list[str],
    database: str | None = None,
    s3_output: str | None = None,
    workgroup: str = "primary",
    encryption: str | None = None,
    kms_key: str | None = None,
    params: dict[str, Any] | list[str] | None = None,
    paramstyle: Literal["qmark", "named"] = "named",
    boto3_session: boto3.Session | None = None,
    client_request_token: str | list[str] | None = None,
    athena_cache_settings: typing.AthenaCacheSettings | None = None,
    athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
    data_source: str | None = None,
    wait: bool = False,
    check_workgroup: bool = True,
    enforce_workgroup: bool = False,
    as_iterator: bool = False,
    use_threads: bool | int = False,
) -> list[str] | list[dict[str, Any]]:
    """
    Start multiple SQL queries against Amazon Athena.

    This function is the multi-query variant of ``start_query_execution``.
    It supports caching, idempotent request tokens, workgroup configuration,
    sequential or parallel submission, and lazy or eager iteration.

    Parameters
    ----------
    sqls : list[str]
        List of SQL queries to execute.
    database : str, optional
        AWS Glue/Athena database name.
    s3_output : str, optional
        S3 path where query results will be stored.
    workgroup : str, default 'primary'
        Athena workgroup name.
    encryption : str, optional
        One of {'SSE_S3', 'SSE_KMS', 'CSE_KMS'}.
    kms_key : str, optional
        KMS key ARN/ID, required if using KMS-based encryption.
    params : dict or list, optional
        Query parameters, applied to every query in ``sqls``.
        Behavior depends on ``paramstyle``.
    paramstyle : {'named', 'qmark'}, default 'named'
        Parameter substitution style:
        - 'named': ``{"name": "value"}`` and query must use ``:name``.
        - 'qmark': list of values, substituted sequentially.
    boto3_session : boto3.Session, optional
        Existing boto3 session. A new session will be created if None.
    client_request_token : str | list[str], optional
        Idempotency token(s) for Athena:
        - If a string: suffixed with ``-<index>`` to generate one unique token
          per query. The base is truncated and the result right-padded with
          ``"x"`` so every token is 32-128 characters long and keeps its index.
        - If a list: must have same length as ``sqls``; tokens are used as given.
        - If None: no token provided (duplicate submissions possible).
    athena_cache_settings : dict, optional
        Wrangler cache settings to reuse results when possible. The cache is
        only consulted for queries that have no client request token.
    athena_query_wait_polling_delay : float, optional
        Interval in seconds between query status checks when waiting.
        Defaults to ``_QUERY_WAIT_POLLING_DELAY``.
    data_source : str, optional
        Data catalog name (default 'AwsDataCatalog').
    wait : bool, default False
        If True, block until queries complete and return their execution details.
        If False, return query IDs immediately.
    check_workgroup : bool, default True
        If True, call GetWorkGroup once to retrieve workgroup configuration.
        If False, build a workgroup config from provided parameters (faster, fewer API calls).
    enforce_workgroup : bool, default False
        If True, mark the dummy workgroup config as "enforced" when skipping GetWorkGroup.
    as_iterator : bool, default False
        If True, return a lazy iterator instead of a list.
    use_threads : bool | int, default False
        Controls parallelism:
        - False: submit queries sequentially.
        - True: use ``os.cpu_count()`` worker threads (4 if it cannot be determined).
        - int: number of worker threads to use.

    Returns
    -------
    list[str] | list[dict[str, Any]] | Iterator
        - If ``wait=False``: list or iterator of query execution IDs.
        - If ``wait=True``: list or iterator of query execution metadata dicts.

    Raises
    ------
    ValueError
        If ``client_request_token`` is a list whose length differs from ``sqls``.

    Examples
    --------
    Sequential, no wait:
    >>> qids = wr.athena.start_query_executions(
    ...     sqls=["SELECT 1", "SELECT 2"],
    ...     database="default",
    ...     s3_output="s3://my-bucket/results/",
    ... )
    >>> print(list(qids))
    ['abc-123...', 'def-456...']

    Parallel execution with 8 threads:
    >>> qids = wr.athena.start_query_executions(
    ...     sqls=["SELECT 1", "SELECT 2", "SELECT 3"],
    ...     database="default",
    ...     s3_output="s3://my-bucket/results/",
    ...     use_threads=8,
    ... )

    Waiting for completion and retrieving metadata:
    >>> results = wr.athena.start_query_executions(
    ...     sqls=["SELECT 1"],
    ...     database="default",
    ...     s3_output="s3://my-bucket/results/",
    ...     wait=True
    ... )
    >>> print(results[0]["Status"]["State"])
    'SUCCEEDED'
    """
    session = boto3_session or boto3.Session()
    query_count = len(sqls)

    # Build exactly one (possibly None) idempotency token per query.
    tokens: list[str | None]
    if isinstance(client_request_token, list):
        if len(client_request_token) != query_count:
            raise ValueError("Length of client_request_token list must match number of queries in sqls")
        tokens = list(client_request_token)
    elif isinstance(client_request_token, str):
        tokens = []
        for index in range(query_count):
            suffix = f"-{index}"
            # Truncate the base, not the final string, so the index suffix that
            # makes each token unique is never cut off at the 128-char limit.
            base = client_request_token[: 128 - len(suffix)]
            tokens.append(f"{base}{suffix}".ljust(32, "x"))
    else:
        tokens = [None] * query_count

    # Each entry is whatever `_apply_formatter` returns; `_submit` unpacks it
    # as a (sql, execution_params) pair.
    formatted_queries = [_apply_formatter(sql, params, paramstyle) for sql in sqls]

    if check_workgroup:
        # NOTE(review): `_utils` here is the top-level `awswrangler._utils` module,
        # while the other helpers in this function are used unqualified - confirm
        # that `_get_workgroup_config` is reachable through this name.
        wg_config: _WorkGroupConfig = _utils._get_workgroup_config(session=session, workgroup=workgroup)
    else:
        wg_config = _WorkGroupConfig(
            enforced=enforce_workgroup,
            s3_output=s3_output,
            encryption=encryption,
            kms_key=kms_key,
        )

    def _submit(item: tuple[Any, str | None]) -> Any:
        """Return a cached query ID when available, otherwise start the query."""
        (query, execution_params), token = item

        if token is None and athena_cache_settings is not None:
            cache_info = _check_for_cached_results(
                sql=query,
                boto3_session=session,
                workgroup=workgroup,
                athena_cache_settings=athena_cache_settings,
            )
            if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
                return cache_info.query_execution_id

        return _start_query_execution(
            sql=query,
            wg_config=wg_config,
            database=database,
            data_source=data_source,
            s3_output=s3_output,
            workgroup=workgroup,
            encryption=encryption,
            kms_key=kms_key,
            execution_params=execution_params,
            client_request_token=token,
            boto3_session=session,
        )

    def _wait(query_execution_id: Any) -> Any:
        """Block until the given query finishes and return what `wait_query` returns."""
        return wait_query(
            query_execution_id=query_execution_id,
            boto3_session=session,
            athena_query_wait_polling_delay=athena_query_wait_polling_delay,
        )

    items = list(zip(formatted_queries, tokens))

    if use_threads is False:
        query_ids = map(_submit, items)
    else:
        max_workers = (os.cpu_count() or 4) if use_threads is True else int(use_threads)
        executor = ThreadPoolExecutor(max_workers=max_workers)
        try:
            # `Executor.map` schedules every item right away; only result
            # retrieval is lazy.
            query_ids = executor.map(_submit, items)
        finally:
            # Already-queued work still runs after a non-blocking shutdown; this
            # just lets the worker threads exit once the queue is drained
            # instead of leaking them for the lifetime of the process.
            executor.shutdown(wait=False)

    if not as_iterator:
        # Submit everything before waiting on anything, so eager sequential
        # callers do not serialize "submit, wait, submit, wait".
        query_ids = list(query_ids)

    if not wait:
        return query_ids

    results_iter = map(_wait, query_ids)
    return results_iter if as_iterator else list(results_iter)
362+
171363

172364
def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None:
173365
"""Stop a query execution.

0 commit comments

Comments
 (0)