Skip to content

Commit c3c2694

Browse files
Add ctas_write_compression argument to athena.read_sql_query (#1795)
1 parent 81f4d5b commit c3c2694

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

awswrangler/athena/_read.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def _resolve_query_without_cache_ctas(
261261
alt_database: Optional[str],
262262
name: Optional[str],
263263
ctas_bucketing_info: Optional[Tuple[List[str], int]],
264+
ctas_write_compression: Optional[str],
264265
use_threads: Union[bool, int],
265266
s3_additional_kwargs: Optional[Dict[str, Any]],
266267
boto3_session: boto3.Session,
@@ -276,6 +277,7 @@ def _resolve_query_without_cache_ctas(
276277
s3_output=s3_output,
277278
workgroup=workgroup,
278279
encryption=encryption,
280+
write_compression=ctas_write_compression,
279281
kms_key=kms_key,
280282
wait=True,
281283
boto3_session=boto3_session,
@@ -409,6 +411,7 @@ def _resolve_query_without_cache(
409411
ctas_database_name: Optional[str],
410412
ctas_temp_table_name: Optional[str],
411413
ctas_bucketing_info: Optional[Tuple[List[str], int]],
414+
ctas_write_compression: Optional[str],
412415
use_threads: Union[bool, int],
413416
s3_additional_kwargs: Optional[Dict[str, Any]],
414417
boto3_session: boto3.Session,
@@ -439,6 +442,7 @@ def _resolve_query_without_cache(
439442
alt_database=ctas_database_name,
440443
name=name,
441444
ctas_bucketing_info=ctas_bucketing_info,
445+
ctas_write_compression=ctas_write_compression,
442446
use_threads=use_threads,
443447
s3_additional_kwargs=s3_additional_kwargs,
444448
boto3_session=boto3_session,
@@ -656,7 +660,7 @@ def get_query_results(
656660

657661

658662
@apply_configs
659-
def read_sql_query(
663+
def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
660664
sql: str,
661665
database: str,
662666
ctas_approach: bool = True,
@@ -672,6 +676,7 @@ def read_sql_query(
672676
ctas_database_name: Optional[str] = None,
673677
ctas_temp_table_name: Optional[str] = None,
674678
ctas_bucketing_info: Optional[Tuple[List[str], int]] = None,
679+
ctas_write_compression: Optional[str] = None,
675680
use_threads: Union[bool, int] = True,
676681
boto3_session: Optional[boto3.Session] = None,
677682
max_cache_seconds: int = 0,
@@ -838,6 +843,9 @@ def read_sql_query(
838843
Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the
839844
second element.
840845
Only `str`, `int` and `bool` are supported as column data types for bucketing.
846+
ctas_write_compression: str, optional
847+
Write compression for the temporary table where the CTAS result is stored.
848+
Corresponds to the `write_compression` parameter of the CREATE TABLE AS statement in Athena.
841849
use_threads : bool, int
842850
True to enable concurrent requests, False to disable multiple threads.
843851
If enabled os.cpu_count() will be used as the max number of threads.
@@ -963,6 +971,7 @@ def read_sql_query(
963971
ctas_database_name=ctas_database_name,
964972
ctas_temp_table_name=ctas_temp_table_name,
965973
ctas_bucketing_info=ctas_bucketing_info,
974+
ctas_write_compression=ctas_write_compression,
966975
use_threads=use_threads,
967976
s3_additional_kwargs=s3_additional_kwargs,
968977
boto3_session=session,
@@ -987,6 +996,7 @@ def read_sql_table(
987996
ctas_database_name: Optional[str] = None,
988997
ctas_temp_table_name: Optional[str] = None,
989998
ctas_bucketing_info: Optional[Tuple[List[str], int]] = None,
999+
ctas_write_compression: Optional[str] = None,
9901000
use_threads: Union[bool, int] = True,
9911001
boto3_session: Optional[boto3.Session] = None,
9921002
max_cache_seconds: int = 0,
@@ -1131,6 +1141,9 @@ def read_sql_table(
11311141
Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the
11321142
second element.
11331143
Only `str`, `int` and `bool` are supported as column data types for bucketing.
1144+
ctas_write_compression: str, optional
1145+
Write compression for the temporary table where the CTAS result is stored.
1146+
Corresponds to the `write_compression` parameter of the CREATE TABLE AS statement in Athena.
11341147
use_threads : bool, int
11351148
True to enable concurrent requests, False to disable multiple threads.
11361149
If enabled os.cpu_count() will be used as the max number of threads.
@@ -1202,6 +1215,7 @@ def read_sql_table(
12021215
ctas_database_name=ctas_database_name,
12031216
ctas_temp_table_name=ctas_temp_table_name,
12041217
ctas_bucketing_info=ctas_bucketing_info,
1218+
ctas_write_compression=ctas_write_compression,
12051219
use_threads=use_threads,
12061220
boto3_session=boto3_session,
12071221
max_cache_seconds=max_cache_seconds,

awswrangler/athena/_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,11 @@ def start_query_execution(
472472
max_cache_query_inspections=max_cache_query_inspections,
473473
max_remote_cache_entries=max_remote_cache_entries,
474474
)
475+
_logger.debug("cache_info:\n%s", cache_info)
475476

476477
if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
477478
query_execution_id = cache_info.query_execution_id
479+
_logger.debug("Valid cache found. Retrieving...")
478480
else:
479481
wg_config: _WorkGroupConfig = _get_workgroup_config(session=session, workgroup=workgroup)
480482
query_execution_id = _start_query_execution(

tests/test_athena.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22
import logging
33
import string
4+
from unittest.mock import patch
45

56
import boto3
67
import numpy as np
@@ -1252,3 +1253,34 @@ def test_get_query_execution(workgroup0, workgroup1):
12521253
assert isinstance(unprocessed_query_executions_df, pd.DataFrame)
12531254
assert set(query_execution_ids).intersection(set(query_executions_df["QueryExecutionId"].values.tolist()))
12541255
assert {"aaa", "bbb"}.intersection(set(unprocessed_query_executions_df["QueryExecutionId"].values.tolist()))
1256+
1257+
1258+
@pytest.mark.parametrize("compression", [None, "snappy", "gzip"])
def test_read_sql_query_ctas_write_compression(path, glue_database, glue_table, compression):
    """Verify that ``ctas_write_compression`` is forwarded to ``create_ctas_table``.

    A partitioned Parquet dataset is written and registered in the Glue
    catalog, then queried through ``wr.athena.read_sql_query`` with the CTAS
    approach enabled. ``create_ctas_table`` (as referenced from
    ``awswrangler.athena._read``) is patched with a wrapping mock so the call
    still goes through while its arguments are recorded.

    Parameters
    ----------
    path
        Fixture: S3 location the test dataset is written to.
    glue_database
        Fixture: Glue database the table is registered in and queried from.
    glue_table
        Fixture: name of the Glue table created for the test.
    compression
        Value handed to ``ctas_write_compression``; parametrized over
        ``None``, ``"snappy"`` and ``"gzip"``.
    """
    wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
    wr.s3.to_parquet(
        df=get_df(),
        path=path,
        index=True,
        use_threads=True,
        dataset=True,
        mode="overwrite",
        database=glue_database,
        table=glue_table,
        partition_cols=["par0", "par1"],
    )

    # ``wraps=`` keeps the real function running; the mock only records the call.
    with patch(
        "awswrangler.athena._read.create_ctas_table", wraps=wr.athena.create_ctas_table
    ) as mock_create_ctas_table:
        wr.athena.read_sql_query(
            sql=f"SELECT * FROM {glue_table}",
            database=glue_database,
            ctas_approach=True,
            ctas_write_compression=compression,
        )

    mock_create_ctas_table.assert_called_once()

    # The read path hands the value over under the ``write_compression`` keyword,
    # so that is the key to inspect. This must be an assertion: a plain
    # assignment into the kwargs dict would verify nothing.
    create_ctas_table_args = mock_create_ctas_table.call_args.kwargs
    assert create_ctas_table_args["write_compression"] == compression

0 commit comments

Comments
 (0)