
Commit 728c5ec

Add create_ctas_table to Athena module (#1207)
* Add create_ctas_table to Athena module
* Fix is_parquet_format regex
* Add Athena test create_ctas_table
1 parent 1364615 commit 728c5ec
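
For orientation, here is a minimal usage sketch of the API this commit adds. The database name, table name, query, and S3 path are placeholders; the keyword arguments and the returned keys come from the `create_ctas_table` signature in the diff below.

import awswrangler as wr

# Minimal usage sketch of the new function (placeholder names throughout).
# create_ctas_table wraps the SELECT in a CTAS statement and starts the query.
ctas_query_info = wr.athena.create_ctas_table(
    sql="SELECT * FROM my_table",       # any SELECT query
    database="my_db",                   # database holding the source table
    write_compression="snappy",         # optional; format defaults to PARQUET
    s3_output="s3://my-bucket/ctas/",   # optional; an enforcing workgroup overrides it
)

# The returned dict carries the CTAS location in the catalog and the query ID.
print(ctas_query_info["ctas_database"], ctas_query_info["ctas_table"], ctas_query_info["ctas_query_id"])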

File tree

7 files changed (+221, -48 lines)


awswrangler/athena/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 from awswrangler.athena._read import read_sql_query, read_sql_table, unload  # noqa
 from awswrangler.athena._utils import (  # noqa
     create_athena_bucket,
+    create_ctas_table,
     describe_table,
     get_named_query_statement,
     get_query_columns_types,
@@ -25,6 +26,7 @@
     "get_named_query_statement",
     "get_work_group",
     "repair_table",
+    "create_ctas_table",
     "show_create_table",
     "start_query_execution",
     "stop_query_execution",

awswrangler/athena/_cache.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def max_cache_size(self, value: int) -> None:
 def _parse_select_query_from_possible_ctas(possible_ctas: str) -> Optional[str]:
     """Check if `possible_ctas` is a valid parquet-generating CTAS and returns the full SELECT statement."""
     possible_ctas = possible_ctas.lower()
-    parquet_format_regex: str = r"format\s*=\s*\'parquet\'\s*,"
+    parquet_format_regex: str = r"format\s*=\s*\'parquet\'\s*"
     is_parquet_format: Optional[Match[str]] = re.search(pattern=parquet_format_regex, string=possible_ctas)
     if is_parquet_format is not None:
         unstripped_select_statement_regex: str = r"\s+as\s+\(*(select|with).*"
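
The regex change matters because `create_ctas_table` (added below) emits the `format` property last in the `WITH(...)` clause, with no trailing comma, while the old pattern only matched when a comma followed. A small runnable sketch of the difference:

import re

# Sketch of the fix: the old pattern demanded a trailing comma after the
# format property, so a CTAS whose WITH(...) clause ends with format = 'parquet'
# was not recognized as parquet-generating. The query string is illustrative.
old_regex = r"format\s*=\s*\'parquet\'\s*,"
new_regex = r"format\s*=\s*\'parquet\'\s*"
ctas = "create table t with( external_location = 's3://bucket/t', format = 'parquet') as select 1"

assert re.search(old_regex, ctas) is None       # old pattern: no match
assert re.search(new_regex, ctas) is not None   # fixed pattern: match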

awswrangler/athena/_read.py

Lines changed: 16 additions & 44 deletions
@@ -22,6 +22,7 @@
     _QueryMetadata,
     _start_query_execution,
     _WorkGroupConfig,
+    create_ctas_table,
 )

 from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results
@@ -251,7 +252,6 @@ def _resolve_query_without_cache_ctas(
     encryption: Optional[str],
     workgroup: Optional[str],
     kms_key: Optional[str],
-    wg_config: _WorkGroupConfig,
     alt_database: Optional[str],
     name: Optional[str],
     ctas_bucketing_info: Optional[Tuple[List[str], int]],
@@ -260,52 +260,25 @@ def _resolve_query_without_cache_ctas(
     boto3_session: boto3.Session,
     pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
-    path: str = f"{s3_output}/{name}"
-    ext_location: str = "\n" if wg_config.enforced is True else f",\n    external_location = '{path}'\n"
     fully_qualified_name: str = f'"{alt_database}"."{name}"' if alt_database else f'"{database}"."{name}"'
-    bucketing_str = (
-        (f",\n" f"    bucketed_by = ARRAY{ctas_bucketing_info[0]},\n" f"    bucket_count = {ctas_bucketing_info[1]}")
-        if ctas_bucketing_info
-        else ""
-    )
-    sql = (
-        f"CREATE TABLE {fully_qualified_name}\n"
-        f"WITH(\n"
-        f"    format = 'Parquet',\n"
-        f"    parquet_compression = 'SNAPPY'"
-        f"{bucketing_str}"
-        f"{ext_location}"
-        f") AS\n"
-        f"{sql}"
+    ctas_query_info: Dict[str, str] = create_ctas_table(
+        sql=sql,
+        database=database,
+        ctas_table=name,
+        ctas_database=alt_database,
+        bucketing_info=ctas_bucketing_info,
+        data_source=data_source,
+        s3_output=s3_output,
+        workgroup=workgroup,
+        encryption=encryption,
+        kms_key=kms_key,
+        boto3_session=boto3_session,
     )
-    _logger.debug("sql: %s", sql)
-    try:
-        query_id: str = _start_query_execution(
-            sql=sql,
-            wg_config=wg_config,
-            database=database,
-            data_source=data_source,
-            s3_output=s3_output,
-            workgroup=workgroup,
-            encryption=encryption,
-            kms_key=kms_key,
-            boto3_session=boto3_session,
-        )
-    except botocore.exceptions.ClientError as ex:
-        error: Dict[str, Any] = ex.response["Error"]
-        if error["Code"] == "InvalidRequestException" and "Exception parsing query" in error["Message"]:
-            raise exceptions.InvalidCtasApproachQuery(
-                "Is not possible to wrap this query into a CTAS statement. Please use ctas_approach=False."
-            )
-        if error["Code"] == "InvalidRequestException" and "extraneous input" in error["Message"]:
-            raise exceptions.InvalidCtasApproachQuery(
-                "Is not possible to wrap this query into a CTAS statement. Please use ctas_approach=False."
-            )
-        raise ex
-    _logger.debug("query_id: %s", query_id)
+    ctas_query_id: str = ctas_query_info["ctas_query_id"]
+    _logger.debug("ctas_query_id: %s", ctas_query_id)
     try:
         query_metadata: _QueryMetadata = _get_query_metadata(
-            query_execution_id=query_id,
+            query_execution_id=ctas_query_id,
             boto3_session=boto3_session,
             categories=categories,
             metadata_cache_manager=_cache_manager,
@@ -482,7 +455,6 @@ def _resolve_query_without_cache(
         encryption=encryption,
         workgroup=workgroup,
         kms_key=kms_key,
-        wg_config=wg_config,
         alt_database=ctas_database_name,
         name=name,
         ctas_bucketing_info=ctas_bucketing_info,
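
The public entry point is unchanged by this refactor: `ctas_approach=True` now routes through `create_ctas_table` internally instead of assembling the CTAS SQL inline. A hedged reminder of the caller-facing API (names are placeholders):

import awswrangler as wr

# ctas_approach=True wraps the query in a CTAS behind the scenes; after this
# commit the wrapping is delegated to wr.athena.create_ctas_table.
df = wr.athena.read_sql_query(
    sql="SELECT * FROM my_table",  # placeholder query
    database="my_db",              # placeholder database
    ctas_approach=True,
)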

awswrangler/athena/_utils.py

Lines changed: 140 additions & 2 deletions
@@ -3,15 +3,16 @@
 import logging
 import pprint
 import time
+import uuid
 import warnings
 from decimal import Decimal
-from typing import Any, Dict, Generator, List, NamedTuple, Optional, Union, cast
+from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union, cast

 import boto3
 import botocore.exceptions
 import pandas as pd

-from awswrangler import _data_types, _utils, exceptions, s3, sts
+from awswrangler import _data_types, _utils, catalog, exceptions, s3, sts
 from awswrangler._config import apply_configs

 from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results, _LocalMetadataCacheManager
@@ -640,6 +641,143 @@ def describe_table(
     return _parse_describe_table(raw_result)


+@apply_configs
+def create_ctas_table(
+    sql: str,
+    database: str,
+    ctas_table: Optional[str] = None,
+    ctas_database: Optional[str] = None,
+    s3_output: Optional[str] = None,
+    storage_format: Optional[str] = None,
+    write_compression: Optional[str] = None,
+    partitioning_info: Optional[List[str]] = None,
+    bucketing_info: Optional[Tuple[List[str], int]] = None,
+    field_delimiter: Optional[str] = None,
+    schema_only: bool = False,
+    workgroup: Optional[str] = None,
+    data_source: Optional[str] = None,
+    encryption: Optional[str] = None,
+    kms_key: Optional[str] = None,
+    boto3_session: Optional[boto3.Session] = None,
+) -> Dict[str, str]:
+    """Create a new table populated with the results of a SELECT query.
+
+    https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html
+
+    Parameters
+    ----------
+    sql : str
+        SELECT SQL query.
+    database : str
+        The name of the database where the original table is stored.
+    ctas_table : Optional[str], optional
+        The name of the CTAS table.
+        If None, a random string is used.
+    ctas_database : Optional[str], optional
+        The name of the alternative database where the CTAS table should be stored.
+        If None, `database` is used, that is the CTAS table is stored in the same database as the original table.
+    s3_output : Optional[str], optional
+        The output Amazon S3 path.
+        If None, either the Athena workgroup or client-side location setting is used.
+        If a workgroup enforces a query results location, then it overrides this argument.
+    storage_format : Optional[str], optional
+        The storage format for the CTAS query results, such as ORC, PARQUET, AVRO, JSON, or TEXTFILE.
+        PARQUET by default.
+    write_compression : Optional[str], optional
+        The compression type to use for any storage format that allows compression to be specified.
+    partitioning_info : Optional[List[str]], optional
+        A list of columns by which the CTAS table will be partitioned.
+    bucketing_info : Optional[Tuple[List[str], int]], optional
+        Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the
+        second element.
+        Only `str`, `int` and `bool` are supported as column data types for bucketing.
+    field_delimiter : Optional[str], optional
+        The single-character field delimiter for files in CSV, TSV, and text files.
+    schema_only : bool, optional
+        If True, the table is created with `WITH NO DATA`, i.e. only the schema is written, by default False.
+    workgroup : Optional[str], optional
+        Athena workgroup.
+    data_source : Optional[str], optional
+        Data Source / Catalog name. If None, 'AwsDataCatalog' is used.
+    encryption : str, optional
+        Valid values: [None, 'SSE_S3', 'SSE_KMS']. Note: 'CSE_KMS' is not supported.
+    kms_key : str, optional
+        For SSE-KMS, this is the KMS key ARN or ID.
+    boto3_session : Optional[boto3.Session], optional
+        Boto3 Session. The default boto3 session is used if boto3_session is None.
+
+    Returns
+    -------
+    Dict[str, str]
+        A dictionary with the ID of the query, and the CTAS database and table names.
+    """
+    ctas_table = catalog.sanitize_table_name(ctas_table) if ctas_table else f"temp_table_{uuid.uuid4().hex}"
+    ctas_database = ctas_database if ctas_database else database
+    fully_qualified_name = f'"{ctas_database}"."{ctas_table}"'
+
+    wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
+    s3_output = _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=boto3_session)
+    s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
+    # If the workgroup enforces an external location, then it overrides the user supplied argument
+    external_location_str: str = (
+        f"    external_location = '{s3_output}/{ctas_table}',\n" if (not wg_config.enforced) and (s3_output) else ""
+    )
+
+    # At least one property must be specified within `WITH()` in the query. We default to `PARQUET` for `storage_format`
+    storage_format_str: str = f"""    format = '{storage_format.upper() if storage_format else "PARQUET"}'"""
+    write_compression_str: str = (
+        f"    write_compression = '{write_compression.upper()}',\n" if write_compression else ""
+    )
+    partitioning_str: str = f"    partitioned_by = ARRAY{partitioning_info},\n" if partitioning_info else ""
+    bucketing_str: str = (
+        f"    bucketed_by = ARRAY{bucketing_info[0]},\n    bucket_count = {bucketing_info[1]},\n"
+        if bucketing_info
+        else ""
+    )
+    field_delimiter_str: str = f"    field_delimiter = '{field_delimiter}',\n" if field_delimiter else ""
+    schema_only_str: str = "\nWITH NO DATA" if schema_only else ""
+
+    ctas_sql = (
+        f"CREATE TABLE {fully_qualified_name}\n"
+        f"WITH(\n"
+        f"{external_location_str}"
+        f"{partitioning_str}"
+        f"{bucketing_str}"
+        f"{field_delimiter_str}"
+        f"{write_compression_str}"
+        f"{storage_format_str}"
+        f")\n"
+        f"AS {sql}"
+        f"{schema_only_str}"
+    )
+    _logger.debug("ctas sql: %s", ctas_sql)
+
+    try:
+        query_id: str = _start_query_execution(
+            sql=ctas_sql,
+            wg_config=wg_config,
+            database=database,
+            data_source=data_source,
+            s3_output=s3_output,
+            workgroup=workgroup,
+            encryption=encryption,
+            kms_key=kms_key,
+            boto3_session=boto3_session,
+        )
+    except botocore.exceptions.ClientError as ex:
+        error: Dict[str, Any] = ex.response["Error"]
+        if error["Code"] == "InvalidRequestException" and "Exception parsing query" in error["Message"]:
+            raise exceptions.InvalidCtasApproachQuery(
+                f"It is not possible to wrap this query into a CTAS statement. Root error message: {error['Message']}"
+            )
+        if error["Code"] == "InvalidRequestException" and "extraneous input" in error["Message"]:
+            raise exceptions.InvalidCtasApproachQuery(
+                f"It is not possible to wrap this query into a CTAS statement. Root error message: {error['Message']}"
+            )
+        raise ex
+    return {"ctas_database": ctas_database, "ctas_table": ctas_table, "ctas_query_id": query_id}
+
+
 @apply_configs
 def show_create_table(
     table: str,
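
To make the string assembly above concrete, this standalone sketch reproduces the `WITH(...)` construction for one illustrative combination of arguments (bucketing plus an explicit output path); it makes no AWS calls and every name is a placeholder.

# Standalone sketch mirroring the WITH(...) assembly in create_ctas_table,
# for one illustrative set of arguments; all names below are placeholders.
bucketing_info = (["col0"], 2)
s3_output = "s3://my-bucket/ctas"
ctas_table = "temp_table_abc123"

external_location_str = f"    external_location = '{s3_output}/{ctas_table}',\n"
bucketing_str = f"    bucketed_by = ARRAY{bucketing_info[0]},\n    bucket_count = {bucketing_info[1]},\n"
storage_format_str = "    format = 'PARQUET'"  # the default when storage_format is None

ctas_sql = (
    f'CREATE TABLE "my_db"."{ctas_table}"\n'
    "WITH(\n"
    f"{external_location_str}"
    f"{bucketing_str}"
    f"{storage_format_str}"
    ")\n"
    "AS SELECT * FROM my_table"
)
print(ctas_sql)  # shows the exact CTAS statement Athena would receive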

docs/source/api.rst

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ Amazon Athena
     :toctree: stubs

     create_athena_bucket
+    create_ctas_table
     get_query_columns_types
     get_query_execution
     get_named_query_statement

tests/_utils.py

Lines changed: 9 additions & 1 deletion
@@ -2,7 +2,7 @@
 import time
 from datetime import datetime
 from decimal import Decimal
-from typing import Iterator
+from typing import Dict, Iterator

 import boto3
 import botocore.exceptions
@@ -501,6 +501,14 @@ def ensure_data_types_csv(df, governed=False):
     assert str(df["par1"].dtype) == "string"


+def ensure_athena_ctas_table(ctas_query_info: Dict[str, str], boto3_session: boto3.Session) -> None:
+    query_metadata = wr.athena._utils._get_query_metadata(
+        query_execution_id=ctas_query_info["ctas_query_id"], boto3_session=boto3_session
+    )
+    assert query_metadata.raw_payload["Status"]["State"] == "SUCCEEDED"
+    wr.catalog.delete_table_if_exists(table=ctas_query_info["ctas_table"], database=ctas_query_info["ctas_database"])
+
+
 def ensure_athena_query_metadata(df, ctas_approach=True, encrypted=False):
     assert df.query_metadata is not None
     assert isinstance(df.query_metadata, dict)

tests/test_athena.py

Lines changed: 52 additions & 0 deletions
@@ -10,6 +10,7 @@
 import awswrangler as wr

 from ._utils import (
+    ensure_athena_ctas_table,
     ensure_athena_query_metadata,
     ensure_data_types,
     ensure_data_types_category,
@@ -148,6 +149,57 @@ def test_athena_read_sql_ctas_bucketing(path, path2, glue_table, glue_table2, gl
     assert df_ctas.equals(df_no_ctas)


+def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_ctas_database, kms_key):
+    boto3_session = boto3.DEFAULT_SESSION
+    wr.s3.to_parquet(
+        df=get_df_list(),
+        path=path,
+        index=False,
+        use_threads=True,
+        dataset=True,
+        mode="overwrite",
+        database=glue_database,
+        table=glue_table,
+        partition_cols=["par0", "par1"],
+    )
+
+    # Select *
+    ctas_query_info = wr.athena.create_ctas_table(
+        sql=f"select * from {glue_table}",
+        database=glue_database,
+        encryption="SSE_KMS",
+        kms_key=kms_key,
+    )
+    ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session)
+
+    # Schema only (i.e. WITH NO DATA)
+    ctas_query_info = wr.athena.create_ctas_table(
+        sql=f"select * from {glue_table}",
+        database=glue_database,
+        ctas_table=glue_table2,
+        schema_only=True,
+    )
+    ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session)
+
+    # Convert to new data storage and compression
+    ctas_query_info = wr.athena.create_ctas_table(
+        sql=f"select string, bool from {glue_table}",
+        database=glue_database,
+        storage_format="avro",
+        write_compression="snappy",
+    )
+    ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session)
+
+    # Partition and save to CTAS database
+    ctas_query_info = wr.athena.create_ctas_table(
+        sql=f"select * from {glue_table}",
+        database=glue_database,
+        ctas_database=glue_ctas_database,
+        partitioning_info=["par0", "par1"],
+    )
+    ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session)
+
+
 def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1):
     wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
     wr.s3.to_parquet(
