
Commit 0545dc8

Add get_query_results to the Athena module (#1499)
* Add get_query_results function to Athena module
* Add test_get_query_results to test_athena
1 parent c0d1c97 commit 0545dc8

File tree

3 files changed: +138 -1 lines changed

awswrangler/athena/__init__.py
awswrangler/athena/_read.py
tests/test_athena.py

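In short, the commit lets callers fetch the output of an already-finished Athena query by its execution id instead of re-running the SQL. A minimal sketch of the intended call pattern, assuming a build that includes this commit ("my_database" is an illustrative placeholder, not a value from this commit):

import awswrangler as wr

# Fire the query once, wait for it to finish, then fetch the stored
# results by execution id rather than re-executing the SQL.
query_id = wr.athena.start_query_execution(sql="SELECT 1 AS col0", database="my_database")
wr.athena.wait_query(query_execution_id=query_id)

df = wr.athena.get_query_results(query_execution_id=query_id)
print(df)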

awswrangler/athena/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 """Amazon Athena Module."""

-from awswrangler.athena._read import read_sql_query, read_sql_table, unload  # noqa
+from awswrangler.athena._read import get_query_results, read_sql_query, read_sql_table, unload  # noqa
 from awswrangler.athena._utils import (  # noqa
     create_athena_bucket,
     create_ctas_table,
@@ -23,6 +23,7 @@
     "describe_table",
     "get_query_columns_types",
     "get_query_execution",
+    "get_query_results",
     "get_named_query_statement",
     "get_work_group",
     "repair_table",

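Together with the new import above, the added __all__ entry is what exposes the function at the package namespace; a quick sanity check, assuming awswrangler is installed with this commit applied:

import awswrangler as wr

# The new name is re-exported by the athena subpackage and listed in __all__.
assert "get_query_results" in wr.athena.__all__
assert callable(wr.athena.get_query_results)
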
awswrangler/athena/_read.py

Lines changed: 90 additions & 0 deletions
@@ -559,6 +559,96 @@ def _unload(
     return query_metadata


+@apply_configs
+def get_query_results(
+    query_execution_id: str,
+    use_threads: Union[bool, int] = True,
+    boto3_session: Optional[boto3.Session] = None,
+    categories: Optional[List[str]] = None,
+    chunksize: Optional[Union[int, bool]] = None,
+    s3_additional_kwargs: Optional[Dict[str, Any]] = None,
+    pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
+) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
+    """Get AWS Athena SQL query results as a Pandas DataFrame.
+
+    Parameters
+    ----------
+    query_execution_id : str
+        SQL query's execution_id on AWS Athena.
+    use_threads : bool, int
+        True to enable concurrent requests, False to disable multiple threads.
+        If enabled, os.cpu_count() will be used as the max number of threads.
+        If an integer is provided, the specified number is used.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
+    categories : List[str], optional
+        List of column names that should be returned as pandas.Categorical.
+        Recommended for memory restricted environments.
+    chunksize : Union[int, bool], optional
+        If passed, will split the data into an Iterable of DataFrames (memory friendly).
+        If `True`, Wrangler will iterate on the data by files in the most efficient way without guarantee of chunksize.
+        If an `INTEGER` is passed, Wrangler will iterate on the data by a number of rows equal to the received INTEGER.
+    s3_additional_kwargs : Optional[Dict[str, Any]]
+        Forwarded to botocore requests.
+        e.g. s3_additional_kwargs={'RequestPayer': 'requester'}
+    pyarrow_additional_kwargs : Optional[Dict[str, Any]]
+        Forwarded to the ParquetFile class or when converting an Arrow table to Pandas; currently only a
+        "coerce_int96_timestamp_unit" or "timestamp_as_object" argument will be considered. If reading parquet
+        files where you cannot convert a timestamp to pandas Timestamp[ns], consider setting timestamp_as_object=True
+        to allow for timestamp units larger than "ns". If reading parquet data that still uses INT96 (like Athena
+        outputs), you can use coerce_int96_timestamp_unit to specify what timestamp unit to encode INT96 to (by default
+        this is "ns"; if you know the output parquet came from a system that encodes timestamps to a particular unit,
+        set this to that same unit, e.g. coerce_int96_timestamp_unit="ms").
+
+    Returns
+    -------
+    Union[pd.DataFrame, Iterator[pd.DataFrame]]
+        Pandas DataFrame or Generator of Pandas DataFrames if chunksize is passed.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> res = wr.athena.get_query_results(
+    ...     query_execution_id="cbae5b41-8103-4709-95bb-887f88edd4f2"
+    ... )
+
+    """
+    query_metadata: _QueryMetadata = _get_query_metadata(
+        query_execution_id=query_execution_id,
+        boto3_session=boto3_session,
+        categories=categories,
+        metadata_cache_manager=_cache_manager,
+    )
+    client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
+    query_info: Dict[str, Any] = client_athena.get_query_execution(QueryExecutionId=query_execution_id)[
+        "QueryExecution"
+    ]
+    statement_type: Optional[str] = query_info.get("StatementType")
+    if (statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE")) or (
+        statement_type == "DML" and query_info["Query"].startswith("UNLOAD")
+    ):
+        return _fetch_parquet_result(
+            query_metadata=query_metadata,
+            keep_files=True,
+            categories=categories,
+            chunksize=chunksize,
+            use_threads=use_threads,
+            boto3_session=boto3_session,
+            s3_additional_kwargs=s3_additional_kwargs,
+            pyarrow_additional_kwargs=pyarrow_additional_kwargs,
+        )
+    if statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
+        return _fetch_csv_result(
+            query_metadata=query_metadata,
+            keep_files=True,
+            chunksize=chunksize,
+            use_threads=use_threads,
+            boto3_session=boto3_session,
+            s3_additional_kwargs=s3_additional_kwargs,
+        )
+    raise exceptions.UndetectedType(f"""Unable to get results for: {query_info["Query"]}.""")
+
+
 @apply_configs
 def read_sql_query(
     sql: str,
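
The implementation dispatches on the original statement: CTAS results (DDL starting with CREATE TABLE) and UNLOAD results are read back from S3 as Parquet via _fetch_parquet_result, any other non-INSERT DML result is read back as CSV via _fetch_csv_result, and anything else raises UndetectedType. Because chunksize is forwarded to the fetchers, the function can also stream large result sets; a sketch, using a placeholder execution id:

import awswrangler as wr

# Stream a completed query's results roughly 100k rows at a time instead
# of materializing one big DataFrame (the execution id is a placeholder).
chunks = wr.athena.get_query_results(
    query_execution_id="cbae5b41-8103-4709-95bb-887f88edd4f2",
    chunksize=100_000,
)
for chunk in chunks:  # Iterator[pd.DataFrame] when chunksize is passed
    print(len(chunk))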

tests/test_athena.py

Lines changed: 46 additions & 0 deletions
@@ -1097,3 +1097,49 @@ def test_start_query_execution_wait(path, glue_database, glue_table):
     assert query_execution_result["Query"] == sql
     assert query_execution_result["StatementType"] == "DML"
     assert query_execution_result["QueryExecutionContext"]["Database"] == glue_database
+
+
+def test_get_query_results(path, glue_table, glue_database):
+
+    sql = (
+        "SELECT CAST("
+        "    ROW(1, ROW(2, ROW(3, '4'))) AS"
+        "    ROW(field0 BIGINT, field1 ROW(field2 BIGINT, field3 ROW(field4 BIGINT, field5 VARCHAR)))"
+        ") AS col0"
+    )
+
+    df_ctas: pd.DataFrame = wr.athena.read_sql_query(
+        sql=sql, database=glue_database, ctas_approach=True, unload_approach=False
+    )
+    query_id_ctas = df_ctas.query_metadata["QueryExecutionId"]
+    df_get_query_results_ctas = wr.athena.get_query_results(query_execution_id=query_id_ctas)
+    pd.testing.assert_frame_equal(df_get_query_results_ctas, df_ctas)
+
+    df_unload: pd.DataFrame = wr.athena.read_sql_query(
+        sql=sql, database=glue_database, ctas_approach=False, unload_approach=True, s3_output=path
+    )
+    query_id_unload = df_unload.query_metadata["QueryExecutionId"]
+    df_get_query_results_df_unload = wr.athena.get_query_results(query_execution_id=query_id_unload)
+    pd.testing.assert_frame_equal(df_get_query_results_df_unload, df_unload)
+
+    wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
+    wr.s3.to_parquet(
+        df=get_df(),
+        path=path,
+        index=True,
+        use_threads=True,
+        dataset=True,
+        mode="overwrite",
+        database=glue_database,
+        table=glue_table,
+        partition_cols=["par0", "par1"],
+    )
+
+    reg_sql = f"SELECT * FROM {glue_table}"
+
+    df_regular: pd.DataFrame = wr.athena.read_sql_query(
+        sql=reg_sql, database=glue_database, ctas_approach=False, unload_approach=False
+    )
+    query_id_regular = df_regular.query_metadata["QueryExecutionId"]
+    df_get_query_results_df_regular = wr.athena.get_query_results(query_execution_id=query_id_regular)
+    pd.testing.assert_frame_equal(df_get_query_results_df_regular, df_regular)
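
The test covers all three storage formats the function has to detect (CTAS Parquet, UNLOAD Parquet, and regular CSV results) and asserts each round-trips identically through get_query_results; given AWS credentials and the repo's glue_database/glue_table/path fixtures, it can be run in isolation with pytest tests/test_athena.py::test_get_query_results. One property worth noting: both fetch branches pass keep_files=True, so reading results does not delete the underlying S3 objects, and the same execution id can be fetched repeatedly. A small sketch of that guarantee, with a placeholder id:

import awswrangler as wr

# keep_files=True means the S3 result files survive each read, so the
# same execution id can be fetched twice (the id below is a placeholder).
qid = "cbae5b41-8103-4709-95bb-887f88edd4f2"
first = wr.athena.get_query_results(query_execution_id=qid)
second = wr.athena.get_query_results(query_execution_id=qid)
assert first.equals(second)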
