Skip to content

Commit 2e039ec

Browse files
authored
Add CTAS bucketing to wr.athena.read_sql_query (#802)
1 parent 4edc97d commit 2e039ec

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

awswrangler/athena/_read.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
import sys
88
import uuid
9-
from typing import Any, Dict, Iterator, List, Match, NamedTuple, Optional, Union
9+
from typing import Any, Dict, Iterator, List, Match, NamedTuple, Optional, Tuple, Union
1010

1111
import boto3
1212
import botocore.exceptions
@@ -385,18 +385,25 @@ def _resolve_query_without_cache_ctas(
385385
wg_config: _WorkGroupConfig,
386386
alt_database: Optional[str],
387387
name: Optional[str],
388+
ctas_bucketing_info: Optional[Tuple[List[str], int]],
388389
use_threads: bool,
389390
s3_additional_kwargs: Optional[Dict[str, Any]],
390391
boto3_session: boto3.Session,
391392
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
392393
path: str = f"{s3_output}/{name}"
393394
ext_location: str = "\n" if wg_config.enforced is True else f",\n external_location = '{path}'\n"
394395
fully_qualified_name: str = f'"{alt_database}"."{name}"' if alt_database else f'"{database}"."{name}"'
396+
bucketing_str = (
397+
(f",\n" f" bucketed_by = ARRAY{ctas_bucketing_info[0]},\n" f" bucket_count = {ctas_bucketing_info[1]}")
398+
if ctas_bucketing_info
399+
else ""
400+
)
395401
sql = (
396402
f"CREATE TABLE {fully_qualified_name}\n"
397403
f"WITH(\n"
398404
f" format = 'Parquet',\n"
399405
f" parquet_compression = 'SNAPPY'"
406+
f"{bucketing_str}"
400407
f"{ext_location}"
401408
f") AS\n"
402409
f"{sql}"
@@ -521,6 +528,7 @@ def _resolve_query_without_cache(
521528
keep_files: bool,
522529
ctas_database_name: Optional[str],
523530
ctas_temp_table_name: Optional[str],
531+
ctas_bucketing_info: Optional[Tuple[List[str], int]],
524532
use_threads: bool,
525533
s3_additional_kwargs: Optional[Dict[str, Any]],
526534
boto3_session: boto3.Session,
@@ -553,6 +561,7 @@ def _resolve_query_without_cache(
553561
wg_config=wg_config,
554562
alt_database=ctas_database_name,
555563
name=name,
564+
ctas_bucketing_info=ctas_bucketing_info,
556565
use_threads=use_threads,
557566
s3_additional_kwargs=s3_additional_kwargs,
558567
boto3_session=boto3_session,
@@ -593,6 +602,7 @@ def read_sql_query(
593602
keep_files: bool = True,
594603
ctas_database_name: Optional[str] = None,
595604
ctas_temp_table_name: Optional[str] = None,
605+
ctas_bucketing_info: Optional[Tuple[List[str], int]] = None,
596606
use_threads: bool = True,
597607
boto3_session: Optional[boto3.Session] = None,
598608
max_cache_seconds: int = 0,
@@ -733,6 +743,10 @@ def read_sql_query(
733743
The name of the temporary table and also the directory name on S3 where the CTAS result is stored.
734744
If None, it will use the following random pattern: `f"temp_table_{uuid.uuid4().hex}"`.
735745
On S3 this directory will be under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`.
746+
ctas_bucketing_info: Tuple[List[str], int], optional
747+
Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the
748+
second element.
749+
Only `str`, `int` and `bool` are supported as column data types for bucketing.
736750
use_threads : bool
737751
True to enable concurrent requests, False to disable multiple threads.
738752
If enabled os.cpu_count() will be used as the max number of threads.
@@ -841,6 +855,7 @@ def read_sql_query(
841855
keep_files=keep_files,
842856
ctas_database_name=ctas_database_name,
843857
ctas_temp_table_name=ctas_temp_table_name,
858+
ctas_bucketing_info=ctas_bucketing_info,
844859
use_threads=use_threads,
845860
s3_additional_kwargs=s3_additional_kwargs,
846861
boto3_session=session,
@@ -861,6 +876,7 @@ def read_sql_table(
861876
keep_files: bool = True,
862877
ctas_database_name: Optional[str] = None,
863878
ctas_temp_table_name: Optional[str] = None,
879+
ctas_bucketing_info: Optional[Tuple[List[str], int]] = None,
864880
use_threads: bool = True,
865881
boto3_session: Optional[boto3.Session] = None,
866882
max_cache_seconds: int = 0,
@@ -995,6 +1011,10 @@ def read_sql_table(
9951011
The name of the temporary table and also the directory name on S3 where the CTAS result is stored.
9961012
If None, it will use the following random pattern: `f"temp_table_{uuid.uuid4().hex}"`.
9971013
On S3 this directory will be under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`.
1014+
ctas_bucketing_info: Tuple[List[str], int], optional
1015+
Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the
1016+
second element.
1017+
Only `str`, `int` and `bool` are supported as column data types for bucketing.
9981018
use_threads : bool
9991019
True to enable concurrent requests, False to disable multiple threads.
10001020
If enabled os.cpu_count() will be used as the max number of threads.
@@ -1053,6 +1073,7 @@ def read_sql_table(
10531073
keep_files=keep_files,
10541074
ctas_database_name=ctas_database_name,
10551075
ctas_temp_table_name=ctas_temp_table_name,
1076+
ctas_bucketing_info=ctas_bucketing_info,
10561077
use_threads=use_threads,
10571078
boto3_session=boto3_session,
10581079
max_cache_seconds=max_cache_seconds,

tests/test_athena.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,33 @@ def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database,
121121
assert len(wr.s3.list_objects(path=path3)) == 0
122122

123123

124+
def test_athena_read_sql_ctas_bucketing(path, path2, glue_table, glue_table2, glue_database, glue_ctas_database):
125+
df = pd.DataFrame({"c0": [0, 1], "c1": ["foo", "bar"]})
126+
wr.s3.to_parquet(
127+
df=df,
128+
path=path,
129+
dataset=True,
130+
database=glue_database,
131+
table=glue_table,
132+
)
133+
df_ctas = wr.athena.read_sql_query(
134+
sql=f"SELECT * FROM {glue_table}",
135+
ctas_approach=True,
136+
database=glue_database,
137+
ctas_database_name=glue_ctas_database,
138+
ctas_temp_table_name=glue_table2,
139+
ctas_bucketing_info=(["c0"], 1),
140+
s3_output=path2,
141+
)
142+
df_no_ctas = wr.athena.read_sql_query(
143+
sql=f"SELECT * FROM {glue_table}",
144+
ctas_approach=False,
145+
database=glue_database,
146+
s3_output=path2,
147+
)
148+
assert df_ctas.equals(df_no_ctas)
149+
150+
124151
def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1):
125152
wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
126153
wr.s3.to_parquet(

0 commit comments

Comments
 (0)