Skip to content

Commit 5e2f96e

Browse files
Merge branch 'main' into release-3.0.0
2 parents 665141e + ac82270 commit 5e2f96e

23 files changed

+187
-51
lines changed

awswrangler/_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,12 @@ def table_refs_to_df(tables: List[pa.Table], kwargs: Dict[str, Any]) -> pd.DataF
414414
return _table_to_df(pa.concat_tables(tables, promote=True), kwargs=kwargs)
415415

416416

417+
@engine.dispatch_on_engine
def is_pandas_frame(obj: Any) -> bool:
    """Check if the passed object is a Pandas DataFrame."""
    return isinstance(obj, pd.DataFrame)
421+
422+
417423
def list_to_arrow_table(
418424
mapping: List[Dict[str, Any]],
419425
schema: Optional[pa.Schema] = None,

awswrangler/athena/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
get_named_query_statement,
1010
get_query_columns_types,
1111
get_query_execution,
12+
get_query_executions,
1213
get_work_group,
14+
list_query_executions,
1315
repair_table,
1416
show_create_table,
1517
start_query_execution,
@@ -24,10 +26,12 @@
2426
"describe_table",
2527
"get_query_columns_types",
2628
"get_query_execution",
29+
"get_query_executions",
2730
"get_query_results",
2831
"get_named_query_statement",
2932
"get_work_group",
3033
"generate_create_query",
34+
"list_query_executions",
3135
"repair_table",
3236
"create_ctas_table",
3337
"show_create_table",

awswrangler/athena/_read.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,23 @@ def _fetch_parquet_result(
105105
if not paths:
106106
if not temp_table_fqn:
107107
raise exceptions.EmptyDataFrame("Query would return untyped, empty dataframe.")
108+
108109
database, temp_table_name = map(lambda x: x.replace('"', ""), temp_table_fqn.split("."))
109110
dtype_dict = catalog.get_table_types(database=database, table=temp_table_name, boto3_session=boto3_session)
110111
df = pd.DataFrame(columns=list(dtype_dict.keys()))
111112
df = cast_pandas_with_athena_types(df=df, dtype=dtype_dict)
112113
df = _apply_query_metadata(df=df, query_metadata=query_metadata)
114+
115+
if chunked:
116+
return (df,)
117+
113118
return df
119+
114120
if not pyarrow_additional_kwargs:
115121
pyarrow_additional_kwargs = {}
116122
if categories:
117123
pyarrow_additional_kwargs["categories"] = categories
124+
118125
ret = s3.read_parquet(
119126
path=paths,
120127
use_threads=use_threads,

awswrangler/athena/_utils.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,3 +1146,104 @@ def get_query_execution(query_execution_id: str, boto3_session: Optional[boto3.S
11461146
QueryExecutionId=query_execution_id,
11471147
)
11481148
return cast(Dict[str, Any], response["QueryExecution"])
1149+
1150+
1151+
def get_query_executions(
    query_execution_ids: List[str], return_unprocessed: bool = False, boto3_session: Optional[boto3.Session] = None
) -> Union[Tuple[pd.DataFrame, pd.DataFrame], pd.DataFrame]:
    """From specified query execution IDs, return a DataFrame of query execution details.

    https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/athena.html#Athena.Client.batch_get_query_execution

    Parameters
    ----------
    query_execution_ids : List[str]
        Athena query execution IDs.
    return_unprocessed : bool
        True to also return the query execution IDs that could not be processed.
        False to only return the DataFrame of query execution details.
        Default is False.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.

    Returns
    -------
    Union[Tuple[pd.DataFrame, pd.DataFrame], pd.DataFrame]
        If ``return_unprocessed`` is False (default), a single DataFrame with the
        query execution details.
        If ``return_unprocessed`` is True, a tuple of two DataFrames: the query
        execution details, and the unprocessed query execution IDs.

    Examples
    --------
    >>> import awswrangler as wr
    >>> query_executions_df, unprocessed_query_executions_df = wr.athena.get_query_executions(
            query_execution_ids=['query-execution-id','query-execution-id1'],
            return_unprocessed=True,
        )
    """
    # The IDs are sent in slices of at most 50 per batch_get_query_execution request.
    chunked_size: int = 50
    query_executions: List[Dict[str, Any]] = []
    unprocessed_query_execution: List[Dict[str, str]] = []
    client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
    for i in range(0, len(query_execution_ids), chunked_size):
        response = client_athena.batch_get_query_execution(QueryExecutionIds=query_execution_ids[i : i + chunked_size])
        query_executions += response["QueryExecutions"]
        unprocessed_query_execution += response["UnprocessedQueryExecutionIds"]
    if unprocessed_query_execution and not return_unprocessed:
        # Adjacent literals are concatenated as-is, so the first one ends with a space.
        _logger.warning(
            "Some of query execution ids are unable to be processed. "
            "Set return_unprocessed to True to get unprocessed query execution ids"
        )
    if return_unprocessed:
        return pd.json_normalize(query_executions), pd.json_normalize(unprocessed_query_execution)
    return pd.json_normalize(query_executions)
1200+
1201+
1202+
def list_query_executions(workgroup: Optional[str] = None, boto3_session: Optional[boto3.Session] = None) -> List[str]:
    """Fetch the list of query execution IDs run in the given workgroup.

    When no workgroup is given, the query execution IDs of the primary workgroup are returned.

    https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/athena.html#Athena.Client.list_query_executions

    Parameters
    ----------
    workgroup : str, optional
        The name of the workgroup from which the query execution IDs are being returned.
        If not specified, a list of available query execution IDs for the queries in the primary workgroup is returned.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.

    Returns
    -------
    List[str]
        List of query execution IDs.

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.athena.list_query_executions(workgroup='workgroup-name')

    """
    athena_client: boto3.client = _utils.client(service_name="athena", session=boto3_session)
    # NOTE(review): "base" is forwarded to _utils.try_it together with the API
    # arguments - presumably its retry back-off base; confirm against try_it.
    request_args: Dict[str, Any] = {"base": 1}
    if workgroup:
        request_args["WorkGroup"] = workgroup
    execution_ids: List[str] = []
    # Request pages until a response no longer carries a "NextToken".
    while True:
        page: Dict[str, Any] = _utils.try_it(
            f=athena_client.list_query_executions,
            ex=botocore.exceptions.ClientError,
            ex_code="ThrottlingException",
            max_num_tries=5,
            **request_args,
        )
        execution_ids.extend(page["QueryExecutionIds"])
        if "NextToken" not in page:
            break
        request_args["NextToken"] = page["NextToken"]
    return execution_ids

awswrangler/distributed/ray/_register.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# pylint: disable=import-outside-toplevel
33
from awswrangler._data_types import pyarrow_types_from_pandas
44
from awswrangler._distributed import MemoryFormatEnum, engine, memory_format
5-
from awswrangler._utils import table_refs_to_df
5+
from awswrangler._utils import is_pandas_frame, table_refs_to_df
66
from awswrangler.distributed.ray._core import ray_remote
77
from awswrangler.lakeformation._read import _get_work_unit_results
88
from awswrangler.s3._delete import _delete_objects
@@ -30,7 +30,7 @@ def register_ray() -> None:
3030
if memory_format.get() == MemoryFormatEnum.MODIN:
3131
from awswrangler.distributed.ray.modin._core import modin_repartition
3232
from awswrangler.distributed.ray.modin._data_types import pyarrow_types_from_pandas_distributed
33-
from awswrangler.distributed.ray.modin._utils import _arrow_refs_to_df
33+
from awswrangler.distributed.ray.modin._utils import _arrow_refs_to_df, _is_pandas_or_modin_frame
3434
from awswrangler.distributed.ray.modin.s3._read_parquet import _read_parquet_distributed
3535
from awswrangler.distributed.ray.modin.s3._read_text import _read_text_distributed
3636
from awswrangler.distributed.ray.modin.s3._write_dataset import (
@@ -52,5 +52,6 @@ def register_ray() -> None:
5252
to_json: modin_repartition(to_json),
5353
to_parquet: modin_repartition(to_parquet),
5454
table_refs_to_df: _arrow_refs_to_df,
55+
is_pandas_frame: _is_pandas_or_modin_frame,
5556
}.items():
5657
engine.register_func(o_f, d_f) # type: ignore

awswrangler/distributed/ray/modin/_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def _arrow_refs_to_df(arrow_refs: List[Callable[..., Any]], kwargs: Optional[Dic
4747
return _to_modin(dataset=ray.data.from_arrow_refs(arrow_refs), to_pandas_kwargs=kwargs)
4848

4949

50+
def _is_pandas_or_modin_frame(obj: Any) -> bool:
    """Check if the passed object is a ``pd.DataFrame`` or a ``modin_pd.DataFrame``."""
    return isinstance(obj, (pd.DataFrame, modin_pd.DataFrame))
52+
53+
5054
@dataclass
5155
class ParamConfig:
5256
"""

awswrangler/s3/_read_parquet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ def read_parquet_table(
729729
partial_cast_function = functools.partial(
730730
_data_types.cast_pandas_with_athena_types, dtype=_extract_partitions_dtypes_from_table_details(response=res)
731731
)
732-
if isinstance(df, pd.DataFrame):
732+
if _utils.is_pandas_frame(df):
733733
return partial_cast_function(df)
734734
# df is a generator, so map is needed for casting dtypes
735735
return map(partial_cast_function, df)

docs/source/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,11 @@ Amazon Athena
119119
generate_create_query
120120
get_query_columns_types
121121
get_query_execution
122+
get_query_executions
122123
get_query_results
123124
get_named_query_statement
124125
get_work_group
126+
list_query_executions
125127
read_sql_query
126128
read_sql_table
127129
repair_table

tests/unit/test_athena.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,19 @@ def test_read_sql_query_wo_results(path, glue_database, glue_table):
679679
ensure_athena_query_metadata(df=df, ctas_approach=False, encrypted=False)
680680

681681

682+
@pytest.mark.parametrize("ctas_approach", [False, True])
def test_read_sql_query_wo_results_chunked(path, glue_database, glue_table, ctas_approach):
    """A chunked read of a freshly created table yields exactly one empty DataFrame."""
    wr.catalog.create_parquet_table(database=glue_database, table=glue_table, path=path, columns_types={"c0": "int"})
    sql = f"SELECT * FROM {glue_database}.{glue_table}"

    chunks = wr.athena.read_sql_query(sql, database=glue_database, ctas_approach=ctas_approach, chunksize=100)

    # counter stays 0 if nothing is yielded; otherwise it ends as the number of chunks.
    counter = 0
    for counter, df in enumerate(chunks, start=1):
        assert df.empty

    assert counter == 1
693+
694+
682695
@pytest.mark.xfail()
683696
def test_read_sql_query_wo_results_ctas(path, glue_database, glue_table):
684697
wr.catalog.create_parquet_table(database=glue_database, table=glue_table, path=path, columns_types={"c0": "int"})
@@ -1304,3 +1317,22 @@ def test_athena_generate_create_query(path, glue_database, glue_table):
13041317
)
13051318
wr.athena.start_query_execution(sql=query, database=glue_database, wait=True)
13061319
assert query == wr.athena.generate_create_query(database=glue_database, table=glue_table)
1320+
1321+
1322+
def test_get_query_execution(workgroup0, workgroup1):
    """Exercise list_query_executions, get_query_execution and get_query_executions together."""
    # Collect the execution IDs listed for both workgroups.
    query_execution_ids = wr.athena.list_query_executions(workgroup=workgroup0) + wr.athena.list_query_executions(
        workgroup=workgroup1
    )
    assert query_execution_ids
    query_execution_detail = wr.athena.get_query_execution(query_execution_id=query_execution_ids[0])
    # Default call (return_unprocessed=False): a single DataFrame is returned.
    query_executions_df = wr.athena.get_query_executions(query_execution_ids)
    assert isinstance(query_executions_df, pd.DataFrame)
    assert isinstance(query_execution_detail, dict)
    # Only requires a non-empty overlap between requested and returned IDs.
    assert set(query_execution_ids).intersection(set(query_executions_df["QueryExecutionId"].values.tolist()))
    # "aaa" and "bbb" do not come from list_query_executions; the last assertion
    # expects at least one of them in the unprocessed DataFrame.
    query_execution_ids1 = query_execution_ids + ["aaa", "bbb"]
    query_executions_df, unprocessed_query_executions_df = wr.athena.get_query_executions(
        query_execution_ids1, return_unprocessed=True
    )
    assert isinstance(unprocessed_query_executions_df, pd.DataFrame)
    assert set(query_execution_ids).intersection(set(query_executions_df["QueryExecutionId"].values.tolist()))
    assert {"aaa", "bbb"}.intersection(set(unprocessed_query_executions_df["QueryExecutionId"].values.tolist()))

tutorials/006 - Amazon Athena.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@
143143
" mode=\"overwrite\",\n",
144144
" database=\"awswrangler_test\",\n",
145145
" table=\"noaa\"\n",
146-
");"
146+
")"
147147
]
148148
},
149149
{

0 commit comments

Comments
 (0)