diff --git a/awswrangler/athena/__init__.py b/awswrangler/athena/__init__.py
index 0556272dc..380bcf627 100644
--- a/awswrangler/athena/__init__.py
+++ b/awswrangler/athena/__init__.py
@@ -4,6 +4,7 @@
     get_query_execution,
     stop_query_execution,
     start_query_execution,
+    start_query_executions,
     wait_query,
 )
 from awswrangler.athena._spark import create_spark_session, run_spark_calculation
@@ -53,6 +54,7 @@
     "create_ctas_table",
     "show_create_table",
     "start_query_execution",
+    "start_query_executions",
     "stop_query_execution",
     "unload",
     "wait_query",
diff --git a/awswrangler/athena/_executions.py b/awswrangler/athena/_executions.py
index b2d3f518a..ce3fb5e98 100644
--- a/awswrangler/athena/_executions.py
+++ b/awswrangler/athena/_executions.py
@@ -3,7 +3,9 @@
 from __future__ import annotations
 
 import logging
+import os
 import time
+from concurrent.futures import ThreadPoolExecutor
 from typing import (
     Any,
     Dict,
@@ -29,6 +31,8 @@
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
+_DEFAULT_MAX_WORKERS = max(4, os.cpu_count() or 4)
+
 
 @apply_configs
 def start_query_execution(
@@ -169,6 +173,191 @@
     return query_execution_id
 
 
+@apply_configs
+def start_query_executions(
+    sqls: list[str],
+    database: str | None = None,
+    s3_output: str | None = None,
+    workgroup: str = "primary",
+    encryption: str | None = None,
+    kms_key: str | None = None,
+    params: dict[str, typing.Any] | list[list[str]] | None = None,
+    paramstyle: Literal["qmark", "named"] = "named",
+    boto3_session: boto3.Session | None = None,
+    client_request_token: str | list[str] | None = None,
+    athena_cache_settings: typing.AthenaCacheSettings | None = None,
+    athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
+    data_source: str | None = None,
+    wait: bool = False,
+    check_workgroup: bool = True,
+    enforce_workgroup: bool = False,
+    as_iterator: bool = False,
+    use_threads: bool | int = False,
+) -> list[str] | list[dict[str, typing.Any]]:
+    """
+    Start multiple SQL queries against Amazon Athena.
+
+    This is the multi-query counterpart to ``start_query_execution``. It supports
+    per-query caching and idempotent client request tokens, optional workgroup
+    validation/enforcement, sequential or thread-pooled parallel dispatch, and
+    either eager (list) or lazy (iterator) consumption. If ``wait=True``, each
+    query is awaited to completion within its submission thread.
+
+    Parameters
+    ----------
+    sqls : list[str]
+        List of SQL queries to execute.
+    database : str, optional
+        AWS Glue/Athena database name.
+    s3_output : str, optional
+        S3 path where query results will be stored.
+    workgroup : str, default 'primary'
+        Athena workgroup name.
+    encryption : str, optional
+        One of {'SSE_S3', 'SSE_KMS', 'CSE_KMS'}.
+    kms_key : str, optional
+        KMS key ARN/ID, required when using KMS-based encryption.
+    params : dict or list, optional
+        Query parameters. With ``paramstyle='named'``, a single dict that is
+        applied to every query; with ``paramstyle='qmark'``, a list of
+        parameter lists, one per query in ``sqls``.
+    paramstyle : {'named', 'qmark'}, default 'named'
+        Parameter substitution style.
+    boto3_session : boto3.Session, optional
+        Existing boto3 session. A new session will be created if None.
+    client_request_token : str | list[str], optional
+        Idempotency token(s). A list must match ``sqls`` in length; a single
+        string is suffixed with each query's index. Result caching is skipped
+        for queries that carry a token.
+    athena_cache_settings : dict, optional
+        awswrangler cache settings for reusing previous query results.
+    athena_query_wait_polling_delay : float, default 1.0
+        Interval between status checks when waiting for queries.
+    data_source : str, optional
+        Data catalog name (default 'AwsDataCatalog').
+    wait : bool, default False
+        If True, block until each query completes.
+    check_workgroup : bool, default True
+        If True, fetch the workgroup configuration from Athena; if False,
+        build it locally from ``s3_output``, ``encryption`` and ``kms_key``.
+    enforce_workgroup : bool, default False
+        Mark the locally built config as enforced (only used when ``check_workgroup=False``).
+    as_iterator : bool, default False
+        If True, return a lazy iterator instead of a list.
+    use_threads : bool | int, default False
+        Parallelism:
+        - False: sequential execution
+        - True: ``max(4, os.cpu_count())`` worker threads
+        - int: number of worker threads
+
+    Returns
+    -------
+    list[str] | list[dict] | Iterator
+        QueryExecutionIds, or execution metadata dicts if ``wait=True``; yielded lazily if ``as_iterator=True``.
+    """
+    session = boto3_session or boto3.Session()
+
+    if isinstance(client_request_token, list):
+        if len(client_request_token) != len(sqls):
+            raise ValueError("Length of client_request_token list must match number of queries in sqls")
+        tokens = client_request_token
+    elif isinstance(client_request_token, str):
+        tokens = [f"{client_request_token}-{i}" for i in range(len(sqls))]
+    else:
+        tokens = [None] * len(sqls)
+
+    if paramstyle == "named":
+        formatted_queries = (_apply_formatter(q, params, "named") for q in sqls)
+    elif paramstyle == "qmark":
+        _params_list = params or [None] * len(sqls)
+        formatted_queries = (_apply_formatter(q, query_params, "qmark") for q, query_params in zip(sqls, _params_list))
+    else:
+        raise ValueError("paramstyle must be 'named' or 'qmark'")
+
+    if check_workgroup:
+        wg_config: _WorkGroupConfig = _get_workgroup_config(session=session, workgroup=workgroup)
+    else:
+        wg_config = _WorkGroupConfig(
+            enforced=enforce_workgroup,
+            s3_output=s3_output,
+            encryption=encryption,
+            kms_key=kms_key,
+        )
+
+    def _submit(item: tuple[tuple[str, list[str] | None], str | None]) -> str | dict[str, typing.Any]:
+        (q, execution_params), token = item
+        _logger.debug("Executing query:\n%s", q)
+
+        # Result caching is only consulted when the query carries no idempotency token.
+        if token is None and athena_cache_settings is not None:
+            cache_info = _check_for_cached_results(
+                sql=q,
+                boto3_session=session,
+                workgroup=workgroup,
+                athena_cache_settings=athena_cache_settings,
+            )
+            _logger.debug("Cache info:\n%s", cache_info)
+            if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
+                _logger.debug("Valid cache found. Retrieving...")
Retrieving...") + return ( + wait_query( + query_execution_id=cache_info.query_execution_id, + boto3_session=session, + athena_query_wait_polling_delay=athena_query_wait_polling_delay, + ) + if wait + else cache_info.query_execution_id + ) + + qid = _start_query_execution( + sql=q, + wg_config=wg_config, + database=database, + data_source=data_source, + s3_output=s3_output, + workgroup=workgroup, + encryption=encryption, + kms_key=kms_key, + execution_params=execution_params, + client_request_token=token, + boto3_session=session, + ) + + if wait: + return wait_query( + query_execution_id=qid, + boto3_session=session, + athena_query_wait_polling_delay=athena_query_wait_polling_delay, + ) + + return qid + + items = zip(formatted_queries, tokens) + + if use_threads is False: + results = map(_submit, items) + return results if as_iterator else list(results) + + max_workers = _DEFAULT_MAX_WORKERS if use_threads is True else int(use_threads) + + if as_iterator: + executor = ThreadPoolExecutor(max_workers=max_workers) + it = executor.map(_submit, items) + + def _iter(): + try: + yield from it + finally: + executor.shutdown(wait=True) + + return _iter() + else: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + return list(executor.map(_submit, items)) + + def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None: """Stop a query execution. diff --git a/awswrangler/athena/_executions.pyi b/awswrangler/athena/_executions.pyi index 5a394d916..585c28d9c 100644 --- a/awswrangler/athena/_executions.pyi +++ b/awswrangler/athena/_executions.pyi @@ -58,6 +58,71 @@ def start_query_execution( data_source: str | None = ..., wait: bool, ) -> str | dict[str, Any]: ... +@overload +def start_query_executions( + sqls: list[str], + database: str | None = ..., + s3_output: str | None = ..., + workgroup: str = ..., + encryption: str | None = ..., + kms_key: str | None = ..., + params: dict[str, Any] | list[str] | None = ..., + paramstyle: Literal["qmark", "named"] = ..., + boto3_session: boto3.Session | None = ..., + client_request_token: str | list[list[str]] | None = ..., + athena_cache_settings: typing.AthenaCacheSettings | None = ..., + athena_query_wait_polling_delay: float = ..., + data_source: str | None = ..., + wait: Literal[False] = ..., + check_workgroup: bool = ..., + enforce_workgroup: bool = ..., + as_iterator: bool = ..., + use_threads: bool | int = ..., +) -> list[str]: ... +@overload +def start_query_executions( + sqls: list[str], + *, + database: str | None = ..., + s3_output: str | None = ..., + workgroup: str = ..., + encryption: str | None = ..., + kms_key: str | None = ..., + params: dict[str, Any] | list[str] | None = ..., + paramstyle: Literal["qmark", "named"] = ..., + boto3_session: boto3.Session | None = ..., + client_request_token: str | list[list[str]] | None = ..., + athena_cache_settings: typing.AthenaCacheSettings | None = ..., + athena_query_wait_polling_delay: float = ..., + data_source: str | None = ..., + wait: Literal[True], + check_workgroup: bool = ..., + enforce_workgroup: bool = ..., + as_iterator: bool = ..., + use_threads: bool | int = ..., +) -> list[dict[str, Any]]: ... 
+@overload
+def start_query_executions(
+    sqls: list[str],
+    *,
+    database: str | None = ...,
+    s3_output: str | None = ...,
+    workgroup: str = ...,
+    encryption: str | None = ...,
+    kms_key: str | None = ...,
+    params: dict[str, Any] | list[list[str]] | None = ...,
+    paramstyle: Literal["qmark", "named"] = ...,
+    boto3_session: boto3.Session | None = ...,
+    client_request_token: str | list[str] | None = ...,
+    athena_cache_settings: typing.AthenaCacheSettings | None = ...,
+    athena_query_wait_polling_delay: float = ...,
+    data_source: str | None = ...,
+    wait: bool,
+    check_workgroup: bool = ...,
+    enforce_workgroup: bool = ...,
+    as_iterator: bool = ...,
+    use_threads: bool | int = ...,
+) -> list[str] | list[dict[str, Any]]: ...
 def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = ...) -> None: ...
 def wait_query(
     query_execution_id: str,
diff --git a/tests/unit/test_athena.py b/tests/unit/test_athena.py
index d747ae001..1376df356 100644
--- a/tests/unit/test_athena.py
+++ b/tests/unit/test_athena.py
@@ -1708,3 +1708,62 @@
         ctas_approach=False,
     )
     assert pandas_equals(df, df2)
+
+
+def test_start_query_executions_ids_and_results(path, glue_database, glue_table):
+    # Prepare table
+    wr.s3.to_parquet(
+        df=get_df(),
+        path=path,
+        index=True,
+        dataset=True,
+        mode="overwrite",
+        database=glue_database,
+        table=glue_table,
+        partition_cols=["par0", "par1"],
+    )
+
+    sqls = [
+        f"SELECT * FROM {glue_table} LIMIT 1",
+        f"SELECT COUNT(*) FROM {glue_table}",
+    ]
+
+    # Case 1: Sequential, return query IDs
+    qids = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=False, use_threads=False)
+    assert isinstance(qids, list)
+    assert all(isinstance(qid, str) for qid in qids)
+    assert len(qids) == len(sqls)
+
+    # Case 2: Sequential, wait for results
+    results = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=True, use_threads=False)
+    assert isinstance(results, list)
+    assert all(isinstance(r, dict) for r in results)
+    assert all("Status" in r for r in results)
+
+    # Case 3: Parallel execution with threads
+    results_parallel = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=True, use_threads=True)
+    assert isinstance(results_parallel, list)
+    assert all(isinstance(r, dict) for r in results_parallel)
+
+
+def test_start_query_executions_as_iterator(path, glue_database, glue_table):
+    # Prepare table
+    wr.s3.to_parquet(
+        df=get_df(),
+        path=path,
+        index=True,
+        dataset=True,
+        mode="overwrite",
+        database=glue_database,
+        table=glue_table,
+        partition_cols=["par0", "par1"],
+    )
+
+    sqls = [f"SELECT * FROM {glue_table} LIMIT 1"]
+
+    # Case: as_iterator=True should return a generator-like object
+    qids_iter = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=False, as_iterator=True)
+    assert not isinstance(qids_iter, list)
+    qids = list(qids_iter)
+    assert len(qids) == 1
+    assert isinstance(qids[0], str)
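Usage sketch for the new API (the database name and queries below are placeholders, not part of the change):

import awswrangler as wr

# "my_db" is a hypothetical Glue database used for illustration only.
sqls = ["SELECT 1", "SELECT 2"]

# Fire-and-forget: returns one QueryExecutionId per query.
qids = wr.athena.start_query_executions(sqls=sqls, database="my_db")

# Submit with 8 worker threads and block until every query finishes;
# returns one execution-metadata dict per query.
results = wr.athena.start_query_executions(
    sqls=sqls,
    database="my_db",
    wait=True,
    use_threads=8,
)

# Lazy consumption: queries are dispatched as the iterator is drained.
for qid in wr.athena.start_query_executions(sqls=sqls, database="my_db", as_iterator=True):
    print(qid)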