|
10 | 10 | cast, |
11 | 11 | ) |
12 | 12 |
|
| 13 | +import os |
13 | 14 | import boto3 |
14 | 15 | import botocore |
15 | 16 | from typing_extensions import Literal |
16 | 17 |
|
| 18 | +from concurrent.futures import ThreadPoolExecutor |
17 | 19 | from awswrangler import _utils, exceptions, typing |
18 | 20 | from awswrangler._config import apply_configs |
| 21 | +from functools import reduce |
19 | 22 |
|
20 | 23 | from ._cache import _CacheInfo, _check_for_cached_results |
21 | 24 | from ._utils import ( |
@@ -168,6 +171,195 @@ def start_query_execution( |
168 | 171 |
|
169 | 172 | return query_execution_id |
170 | 173 |
|
| 174 | +@apply_configs |
| 175 | +def start_query_executions( |
| 176 | + sqls: list[str], |
| 177 | + database: str | None = None, |
| 178 | + s3_output: str | None = None, |
| 179 | + workgroup: str = "primary", |
| 180 | + encryption: str | None = None, |
| 181 | + kms_key: str | None = None, |
| 182 | + params: dict[str, Any] | list[str] | None = None, |
| 183 | + paramstyle: Literal["qmark", "named"] = "named", |
| 184 | + boto3_session: boto3.Session | None = None, |
| 185 | + client_request_token: str | None = None, |
| 186 | + athena_cache_settings: typing.AthenaCacheSettings | None = None, |
| 187 | + athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY, |
| 188 | + data_source: str | None = None, |
| 189 | + wait: bool = False, |
| 190 | + check_workgroup: bool = True, |
| 191 | + enforce_workgroup: bool = False, |
| 192 | + as_iterator: bool = False, |
| 193 | + use_threads: bool | int = False |
| 194 | +) -> list[str] | list[dict[str, Any]]: |
| 195 | + """ |
| 196 | + Start multiple SQL queries against Amazon Athena. |
| 197 | +
|
| 198 | + This function is the multi-query variant of ``start_query_execution``. |
| 199 | + It supports caching, idempotent request tokens, workgroup configuration, |
| 200 | + sequential or parallel execution, and lazy or eager iteration. |
| 201 | +
|
| 202 | + Parameters |
| 203 | + ---------- |
| 204 | + sqls : list[str] |
| 205 | + List of SQL queries to execute. |
| 206 | + database : str, optional |
| 207 | + AWS Glue/Athena database name. |
| 208 | + s3_output : str, optional |
| 209 | + S3 path where query results will be stored. |
| 210 | + workgroup : str, default 'primary' |
| 211 | + Athena workgroup name. |
| 212 | + encryption : str, optional |
| 213 | + One of {'SSE_S3', 'SSE_KMS', 'CSE_KMS'}. |
| 214 | + kms_key : str, optional |
| 215 | + KMS key ARN/ID, required if using KMS-based encryption. |
| 216 | + params : dict or list, optional |
| 217 | + Query parameters. Behavior depends on ``paramstyle``. |
| 218 | + paramstyle : {'named', 'qmark'}, default 'named' |
| 219 | + Parameter substitution style: |
| 220 | + - 'named': ``{"name": "value"}`` and query must use ``:name``. |
| 221 | + - 'qmark': list of values, substituted sequentially. |
| 222 | + boto3_session : boto3.Session, optional |
| 223 | + Existing boto3 session. A new session will be created if None. |
| 224 | + client_request_token : str | list[str], optional |
| 225 | + Idempotency token(s) for Athena: |
| 226 | + - If a string: suffixed with an index to generate unique tokens. |
| 227 | + - If a list: must have same length as ``sqls``. |
| 228 | + - If None: no token provided (duplicate submissions possible). |
| 229 | + Tokens are padded/truncated to comply with Athena’s requirement (32–128 chars). |
| 230 | + athena_cache_settings : dict, optional |
| 231 | + Wrangler cache settings to reuse results when possible. |
| 232 | + athena_query_wait_polling_delay : float, default 1.0 |
| 233 | + Interval in seconds between query status checks when waiting. |
| 234 | + data_source : str, optional |
| 235 | + Data catalog name (default 'AwsDataCatalog'). |
| 236 | + wait : bool, default False |
| 237 | + If True, block until queries complete and return their execution details. |
| 238 | + If False, return query IDs immediately. |
| 239 | + check_workgroup : bool, default True |
| 240 | + If True, call GetWorkGroup once to retrieve workgroup configuration. |
| 241 | + If False, build a workgroup config from provided parameters (faster, fewer API calls). |
| 242 | + enforce_workgroup : bool, default False |
| 243 | + If True, mark the dummy workgroup config as "enforced" when skipping GetWorkGroup. |
| 244 | + as_iterator : bool, default False |
| 245 | + If True, return a lazy iterator instead of a list. |
| 246 | + use_threads : bool | int, default False |
| 247 | + Controls parallelism: |
| 248 | + - False: submit queries sequentially. |
| 249 | + - True: use ``os.cpu_count()`` worker threads. |
| 250 | + - int: number of worker threads to use. |
| 251 | +
|
| 252 | + Returns |
| 253 | + ------- |
| 254 | + list[str] | list[dict[str, Any]] | Iterator |
| 255 | + - If ``wait=False``: list or iterator of query execution IDs. |
| 256 | + - If ``wait=True``: list or iterator of query execution metadata dicts. |
| 257 | +
|
| 258 | + Examples |
| 259 | + -------- |
| 260 | + Sequential, no wait: |
| 261 | + >>> qids = wr.athena.start_query_executions( |
| 262 | + ... sqls=["SELECT 1", "SELECT 2"], |
| 263 | + ... database="default", |
| 264 | + ... s3_output="s3://my-bucket/results/", |
| 265 | + ... ) |
| 266 | + >>> print(list(qids)) |
| 267 | + ['abc-123...', 'def-456...'] |
| 268 | +
|
| 269 | + Parallel execution with 8 threads: |
| 270 | + >>> qids = wr.athena.start_query_executions( |
| 271 | + ... sqls=["SELECT 1", "SELECT 2", "SELECT 3"], |
| 272 | + ... database="default", |
| 273 | + ... s3_output="s3://my-bucket/results/", |
| 274 | + ... use_threads=8, |
| 275 | + ... ) |
| 276 | +
|
| 277 | + Waiting for completion and retrieving metadata: |
| 278 | + >>> results = wr.athena.start_query_executions( |
| 279 | + ... sqls=["SELECT 1"], |
| 280 | + ... database="default", |
| 281 | + ... s3_output="s3://my-bucket/results/", |
| 282 | + ... wait=True |
| 283 | + ... ) |
| 284 | + >>> print(results[0]["Status"]["State"]) |
| 285 | + 'SUCCEEDED' |
| 286 | + """ |
| 287 | + |
| 288 | + session = boto3_session or boto3.Session() |
| 289 | + client = session.client("athena") |
| 290 | + |
| 291 | + if isinstance(client_request_token, list): |
| 292 | + if len(client_request_token) != len(sqls): |
| 293 | + raise ValueError("Length of client_request_token list must match number of queries in sqls") |
| 294 | + tokens = client_request_token |
| 295 | + elif isinstance(client_request_token, str): |
| 296 | + tokens = [f"{client_request_token}-{i}".ljust(32, "x")[:128] for i in range(len(sqls))] |
| 297 | + else: |
| 298 | + tokens = [None] * len(sqls) |
| 299 | + |
| 300 | + formatted_queries = list(map(lambda q: _apply_formatter(q, params, paramstyle), sqls)) |
| 301 | + |
| 302 | + if check_workgroup: |
| 303 | + wg_config: _WorkGroupConfig = _utils._get_workgroup_config(session=session, workgroup=workgroup) |
| 304 | + else: |
| 305 | + wg_config = _WorkGroupConfig( |
| 306 | + enforced=enforce_workgroup, |
| 307 | + s3_output=s3_output, |
| 308 | + encryption=encryption, |
| 309 | + kms_key=kms_key, |
| 310 | + ) |
| 311 | + |
| 312 | + def _submit(item): |
| 313 | + (q, execution_params), token = item |
| 314 | + |
| 315 | + if token is None and athena_cache_settings is not None: |
| 316 | + cache_info = _executions._check_for_cached_results( |
| 317 | + sql=q, |
| 318 | + boto3_session=session, |
| 319 | + workgroup=workgroup, |
| 320 | + athena_cache_settings=athena_cache_settings, |
| 321 | + ) |
| 322 | + if cache_info.has_valid_cache and cache_info.query_execution_id is not None: |
| 323 | + return cache_info.query_execution_id |
| 324 | + |
| 325 | + return _start_query_execution( |
| 326 | + sql=q, |
| 327 | + wg_config=wg_config, |
| 328 | + database=database, |
| 329 | + data_source=data_source, |
| 330 | + s3_output=s3_output, |
| 331 | + workgroup=workgroup, |
| 332 | + encryption=encryption, |
| 333 | + kms_key=kms_key, |
| 334 | + execution_params=execution_params, |
| 335 | + client_request_token=token, |
| 336 | + boto3_session=session, |
| 337 | + ) |
| 338 | + |
| 339 | + items = list(zip(formatted_queries, tokens)) |
| 340 | + |
| 341 | + if use_threads is False: |
| 342 | + query_ids = map(_submit, items) |
| 343 | + else: |
| 344 | + max_workers = ( |
| 345 | + os.cpu_count() or 4 if use_threads is True else int(use_threads) |
| 346 | + ) |
| 347 | + executor = ThreadPoolExecutor(max_workers=max_workers) |
| 348 | + query_ids = executor.map(_submit, items) |
| 349 | + |
| 350 | + if wait: |
| 351 | + results_iter = map( |
| 352 | + lambda qid: wait_query( |
| 353 | + query_execution_id=qid, |
| 354 | + boto3_session=session, |
| 355 | + athena_query_wait_polling_delay=athena_query_wait_polling_delay, |
| 356 | + ), |
| 357 | + query_ids, |
| 358 | + ) |
| 359 | + return results_iter if as_iterator else list(results_iter) |
| 360 | + |
| 361 | + return query_ids if as_iterator else list(query_ids) |
| 362 | + |
171 | 363 |
|
172 | 364 | def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None: |
173 | 365 | """Stop a query execution. |
|
0 commit comments