Skip to content

Commit 5e1c7f0

Browse files
committed
chore: cleanup and CI adjustments
- Removed unused `reduce` import from Athena module. - Applied ruff formatting to `start_query_executions`. - Fixed static check issues to pass CI. - Added ruff check on Athena tests file.
1 parent b6e4d88 commit 5e1c7f0

File tree

3 files changed

+128
-3
lines changed

3 files changed

+128
-3
lines changed

awswrangler/athena/_executions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,21 @@
33
from __future__ import annotations
44

55
import logging
6+
import os
67
import time
8+
from concurrent.futures import ThreadPoolExecutor
79
from typing import (
810
Any,
911
Dict,
1012
cast,
1113
)
1214

13-
import os
1415
import boto3
1516
import botocore
1617
from typing_extensions import Literal
1718

18-
from concurrent.futures import ThreadPoolExecutor
1919
from awswrangler import _utils, exceptions, typing
2020
from awswrangler._config import apply_configs
21-
from functools import reduce
2221

2322
from ._cache import _CacheInfo, _check_for_cached_results
2423
from ._utils import (
@@ -34,6 +33,7 @@
3433

3534
_DEFAULT_MAX_WORKERS = max(4, os.cpu_count() or 4)
3635

36+
3737
@apply_configs
3838
def start_query_execution(
3939
sql: str,
@@ -172,6 +172,7 @@ def start_query_execution(
172172

173173
return query_execution_id
174174

175+
175176
@apply_configs
176177
def start_query_executions(
177178
sqls: list[str],

awswrangler/athena/_executions.pyi

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,71 @@ def start_query_execution(
5858
data_source: str | None = ...,
5959
wait: bool,
6060
) -> str | dict[str, Any]: ...
61+
@overload
62+
def start_query_executions(
63+
sqls: list[str],
64+
database: str | None = ...,
65+
s3_output: str | None = ...,
66+
workgroup: str = ...,
67+
encryption: str | None = ...,
68+
kms_key: str | None = ...,
69+
params: dict[str, Any] | list[str] | None = ...,
70+
paramstyle: Literal["qmark", "named"] = ...,
71+
boto3_session: boto3.Session | None = ...,
72+
client_request_token: str | list[str] | None = ...,
73+
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
74+
athena_query_wait_polling_delay: float = ...,
75+
data_source: str | None = ...,
76+
wait: Literal[False] = ...,
77+
check_workgroup: bool = ...,
78+
enforce_workgroup: bool = ...,
79+
as_iterator: bool = ...,
80+
use_threads: bool | int = ...,
81+
) -> list[str]: ...
82+
@overload
83+
def start_query_executions(
84+
sqls: list[str],
85+
*,
86+
database: str | None = ...,
87+
s3_output: str | None = ...,
88+
workgroup: str = ...,
89+
encryption: str | None = ...,
90+
kms_key: str | None = ...,
91+
params: dict[str, Any] | list[str] | None = ...,
92+
paramstyle: Literal["qmark", "named"] = ...,
93+
boto3_session: boto3.Session | None = ...,
94+
client_request_token: str | list[str] | None = ...,
95+
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
96+
athena_query_wait_polling_delay: float = ...,
97+
data_source: str | None = ...,
98+
wait: Literal[True],
99+
check_workgroup: bool = ...,
100+
enforce_workgroup: bool = ...,
101+
as_iterator: bool = ...,
102+
use_threads: bool | int = ...,
103+
) -> list[dict[str, Any]]: ...
104+
@overload
105+
def start_query_executions(
106+
sqls: list[str],
107+
*,
108+
database: str | None = ...,
109+
s3_output: str | None = ...,
110+
workgroup: str = ...,
111+
encryption: str | None = ...,
112+
kms_key: str | None = ...,
113+
params: dict[str, Any] | list[str] | None = ...,
114+
paramstyle: Literal["qmark", "named"] = ...,
115+
boto3_session: boto3.Session | None = ...,
116+
client_request_token: str | list[str] | None = ...,
117+
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
118+
athena_query_wait_polling_delay: float = ...,
119+
data_source: str | None = ...,
120+
wait: bool,
121+
check_workgroup: bool = ...,
122+
enforce_workgroup: bool = ...,
123+
as_iterator: bool = ...,
124+
use_threads: bool | int = ...,
125+
) -> list[str] | list[dict[str, Any]]: ...
61126
def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = ...) -> None: ...
62127
def wait_query(
63128
query_execution_id: str,

tests/unit/test_athena.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,3 +1708,62 @@ def test_athena_date_recovery(path, glue_database, glue_table):
17081708
ctas_approach=False,
17091709
)
17101710
assert pandas_equals(df, df2)
1711+
1712+
1713+
def test_start_query_executions_ids_and_results(path, glue_database, glue_table):
1714+
# Prepare table
1715+
wr.s3.to_parquet(
1716+
df=get_df(),
1717+
path=path,
1718+
index=True,
1719+
dataset=True,
1720+
mode="overwrite",
1721+
database=glue_database,
1722+
table=glue_table,
1723+
partition_cols=["par0", "par1"],
1724+
)
1725+
1726+
sqls = [
1727+
f"SELECT * FROM {glue_table} LIMIT 1",
1728+
f"SELECT COUNT(*) FROM {glue_table}",
1729+
]
1730+
1731+
# Case 1: Sequential, return query IDs
1732+
qids = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=False, use_threads=False)
1733+
assert isinstance(qids, list)
1734+
assert all(isinstance(qid, str) for qid in qids)
1735+
assert len(qids) == len(sqls)
1736+
1737+
# Case 2: Sequential, wait for results
1738+
results = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=True, use_threads=False)
1739+
assert isinstance(results, list)
1740+
assert all(isinstance(r, dict) for r in results)
1741+
assert all("Status" in r for r in results)
1742+
1743+
# Case 3: Parallel execution with threads
1744+
results_parallel = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=True, use_threads=True)
1745+
assert isinstance(results_parallel, list)
1746+
assert all(isinstance(r, dict) for r in results_parallel)
1747+
1748+
1749+
def test_start_query_executions_as_iterator(path, glue_database, glue_table):
1750+
# Prepare table
1751+
wr.s3.to_parquet(
1752+
df=get_df(),
1753+
path=path,
1754+
index=True,
1755+
dataset=True,
1756+
mode="overwrite",
1757+
database=glue_database,
1758+
table=glue_table,
1759+
partition_cols=["par0", "par1"],
1760+
)
1761+
1762+
sqls = [f"SELECT * FROM {glue_table} LIMIT 1"]
1763+
1764+
# Case: as_iterator=True should return a generator-like object
1765+
qids_iter = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=False, as_iterator=True)
1766+
assert not isinstance(qids_iter, list)
1767+
qids = list(qids_iter)
1768+
assert len(qids) == 1
1769+
assert isinstance(qids[0], str)

0 commit comments

Comments
 (0)