Skip to content

Commit d01824d

Browse files
committed
chore(ci): fix static checks to pass CI
1 parent 0e7345d commit d01824d

File tree

3 files changed

+129
-2
lines changed

3 files changed

+129
-2
lines changed

awswrangler/athena/_executions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
from __future__ import annotations
44

55
import logging
6+
import os
67
import time
8+
from concurrent.futures import ThreadPoolExecutor
79
from typing import (
810
Any,
911
Dict,
1012
cast,
1113
)
1214

13-
import os
1415
import boto3
1516
import botocore
1617
from typing_extensions import Literal
1718

18-
from concurrent.futures import ThreadPoolExecutor
1919
from awswrangler import _utils, exceptions, typing
2020
from awswrangler._config import apply_configs
2121

awswrangler/athena/_executions.pyi

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,71 @@ def start_query_execution(
5858
data_source: str | None = ...,
5959
wait: bool,
6060
) -> str | dict[str, Any]: ...
61+
@overload
62+
def start_query_executions(
63+
sqls: list[str],
64+
database: str | None = ...,
65+
s3_output: str | None = ...,
66+
workgroup: str = ...,
67+
encryption: str | None = ...,
68+
kms_key: str | None = ...,
69+
params: dict[str, Any] | list[str] | None = ...,
70+
paramstyle: Literal["qmark", "named"] = ...,
71+
boto3_session: boto3.Session | None = ...,
72+
client_request_token: str | list[str] | None = ...,
73+
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
74+
athena_query_wait_polling_delay: float = ...,
75+
data_source: str | None = ...,
76+
wait: Literal[False] = ...,
77+
check_workgroup: bool = ...,
78+
enforce_workgroup: bool = ...,
79+
as_iterator: bool = ...,
80+
use_threads: bool | int = ...,
81+
) -> list[str]: ...
82+
@overload
83+
def start_query_executions(
84+
sqls: list[str],
85+
*,
86+
database: str | None = ...,
87+
s3_output: str | None = ...,
88+
workgroup: str = ...,
89+
encryption: str | None = ...,
90+
kms_key: str | None = ...,
91+
params: dict[str, Any] | list[str] | None = ...,
92+
paramstyle: Literal["qmark", "named"] = ...,
93+
boto3_session: boto3.Session | None = ...,
94+
client_request_token: str | list[str] | None = ...,
95+
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
96+
athena_query_wait_polling_delay: float = ...,
97+
data_source: str | None = ...,
98+
wait: Literal[True],
99+
check_workgroup: bool = ...,
100+
enforce_workgroup: bool = ...,
101+
as_iterator: bool = ...,
102+
use_threads: bool | int = ...,
103+
) -> list[dict[str, Any]]: ...
104+
@overload
105+
def start_query_executions(
106+
sqls: list[str],
107+
*,
108+
database: str | None = ...,
109+
s3_output: str | None = ...,
110+
workgroup: str = ...,
111+
encryption: str | None = ...,
112+
kms_key: str | None = ...,
113+
params: dict[str, Any] | list[str] | None = ...,
114+
paramstyle: Literal["qmark", "named"] = ...,
115+
boto3_session: boto3.Session | None = ...,
116+
client_request_token: str | list[str] | None = ...,
117+
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
118+
athena_query_wait_polling_delay: float = ...,
119+
data_source: str | None = ...,
120+
wait: bool,
121+
check_workgroup: bool = ...,
122+
enforce_workgroup: bool = ...,
123+
as_iterator: bool = ...,
124+
use_threads: bool | int = ...,
125+
) -> list[str] | list[dict[str, Any]]: ...
61126
def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = ...) -> None: ...
62127
def wait_query(
63128
query_execution_id: str,

tests/unit/test_athena.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,3 +1708,65 @@ def test_athena_date_recovery(path, glue_database, glue_table):
17081708
ctas_approach=False,
17091709
)
17101710
assert pandas_equals(df, df2)
1711+
1712+
def test_start_query_executions_ids_and_results(path, glue_database, glue_table):
1713+
# Prepare table
1714+
wr.s3.to_parquet(
1715+
df=get_df(),
1716+
path=path,
1717+
index=True,
1718+
dataset=True,
1719+
mode="overwrite",
1720+
database=glue_database,
1721+
table=glue_table,
1722+
partition_cols=["par0", "par1"],
1723+
)
1724+
1725+
sqls = [
1726+
f"SELECT * FROM {glue_table} LIMIT 1",
1727+
f"SELECT COUNT(*) FROM {glue_table}",
1728+
]
1729+
1730+
# Case 1: Sequential, return query IDs
1731+
qids = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=False, use_threads=False)
1732+
assert isinstance(qids, list)
1733+
assert all(isinstance(qid, str) for qid in qids)
1734+
assert len(qids) == len(sqls)
1735+
1736+
# Case 2: Sequential, wait for results
1737+
results = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=True, use_threads=False)
1738+
assert isinstance(results, list)
1739+
assert all(isinstance(r, dict) for r in results)
1740+
assert all("Status" in r for r in results)
1741+
1742+
# Case 3: Parallel execution with threads
1743+
results_parallel = wr.athena.start_query_executions(
1744+
sqls=sqls, database=glue_database, wait=True, use_threads=True
1745+
)
1746+
assert isinstance(results_parallel, list)
1747+
assert all(isinstance(r, dict) for r in results_parallel)
1748+
1749+
1750+
def test_start_query_executions_as_iterator(path, glue_database, glue_table):
1751+
# Prepare table
1752+
wr.s3.to_parquet(
1753+
df=get_df(),
1754+
path=path,
1755+
index=True,
1756+
dataset=True,
1757+
mode="overwrite",
1758+
database=glue_database,
1759+
table=glue_table,
1760+
partition_cols=["par0", "par1"],
1761+
)
1762+
1763+
sqls = [f"SELECT * FROM {glue_table} LIMIT 1"]
1764+
1765+
# Case: as_iterator=True should return a generator-like object
1766+
qids_iter = wr.athena.start_query_executions(
1767+
sqls=sqls, database=glue_database, wait=False, as_iterator=True
1768+
)
1769+
assert not isinstance(qids_iter, list)
1770+
qids = list(qids_iter)
1771+
assert len(qids) == 1
1772+
assert isinstance(qids[0], str)

0 commit comments

Comments
 (0)