Skip to content

Commit 785265c

Browse files
feat: Add max_results to athena.list_query_executions (#2665)
1 parent 36906a7 commit 785265c

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

awswrangler/athena/_utils.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,17 +1114,24 @@ def get_query_executions(
11141114
return pd.json_normalize(query_executions)
11151115

11161116

1117-
def list_query_executions(workgroup: str | None = None, boto3_session: boto3.Session | None = None) -> list[str]:
1117+
def list_query_executions(
1118+
workgroup: str | None = None,
1119+
max_results: int | None = None,
1120+
boto3_session: boto3.Session | None = None,
1121+
) -> list[str]:
11181122
"""Fetch list query execution IDs ran in specified workgroup or primary work group if not specified.
11191123
11201124
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/athena.html#Athena.Client.list_query_executions
11211125
11221126
Parameters
11231127
----------
1124-
workgroup : str
1128+
workgroup: str
11251129
The name of the workgroup from which the query_id are being returned.
11261130
If not specified, a list of available query execution IDs for the queries in the primary workgroup is returned.
1127-
boto3_session : boto3.Session(), optional
1131+
max_results: int, optional
1132+
The maximum number of query execution IDs to return in this request.
1133+
If not present, all execution IDs will be returned.
1134+
boto3_session: boto3.Session(), optional
11281135
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
11291136
11301137
Returns
@@ -1139,9 +1146,14 @@ def list_query_executions(workgroup: str | None = None, boto3_session: boto3.Ses
11391146
11401147
"""
11411148
client_athena = _utils.client(service_name="athena", session=boto3_session)
1142-
kwargs: dict[str, Any] = {"base": 1}
1149+
1150+
kwargs: dict[str, Any] = {}
11431151
if workgroup:
11441152
kwargs["WorkGroup"] = workgroup
1153+
1154+
if max_results is not None:
1155+
kwargs["MaxResults"] = min(max_results, 50)
1156+
11451157
query_list: list[str] = []
11461158
response = _utils.try_it(
11471159
f=client_athena.list_query_executions,
@@ -1151,8 +1163,16 @@ def list_query_executions(workgroup: str | None = None, boto3_session: boto3.Ses
11511163
**kwargs,
11521164
)
11531165
query_list += response["QueryExecutionIds"]
1166+
11541167
while "NextToken" in response:
11551168
kwargs["NextToken"] = response["NextToken"]
1169+
1170+
if max_results is not None:
1171+
if len(query_list) >= max_results:
1172+
break
1173+
1174+
kwargs["MaxResults"] = min(max_results - len(query_list), 50)
1175+
11561176
response = _utils.try_it(
11571177
f=client_athena.list_query_executions,
11581178
ex=botocore.exceptions.ClientError,
@@ -1161,5 +1181,6 @@ def list_query_executions(workgroup: str | None = None, boto3_session: boto3.Ses
11611181
**kwargs,
11621182
)
11631183
query_list += response["QueryExecutionIds"]
1184+
11641185
_logger.debug("Running %d query executions", len(query_list))
11651186
return query_list

tests/unit/test_athena.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,6 +1429,18 @@ def test_get_query_execution(workgroup0, workgroup1):
14291429
assert {"aaa", "bbb"}.intersection(set(unprocessed_query_executions_df["QueryExecutionId"].values.tolist()))
14301430

14311431

1432+
@pytest.mark.parametrize("max_results", [55, 3])
1433+
def test_list_query_executions_max_results(workgroup0: str, max_results: int):
1434+
for _ in range(max_results + 1):
1435+
wr.athena.start_query_execution(sql="SELECT random(10)", workgroup=workgroup0, wait=False)
1436+
1437+
query_execution_ids = wr.athena.list_query_executions(workgroup=workgroup0)
1438+
assert len(query_execution_ids) > max_results
1439+
1440+
query_execution_ids_max_results = wr.athena.list_query_executions(workgroup=workgroup0, max_results=max_results)
1441+
assert len(query_execution_ids_max_results) == max_results
1442+
1443+
14321444
@pytest.mark.parametrize("compression", [None, "snappy", "gzip"])
14331445
def test_read_sql_query_ctas_write_compression(path, glue_database, glue_table, compression):
14341446
wr.s3.to_parquet(

0 commit comments

Comments
 (0)