Skip to content

Commit 14b240c

Browse files
authored
Adding ctas_database_name argument (#595)
1 parent fbea66b commit 14b240c

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

awswrangler/athena/_read.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,15 +374,17 @@ def _resolve_query_without_cache_ctas(
374374
workgroup: Optional[str],
375375
kms_key: Optional[str],
376376
wg_config: _WorkGroupConfig,
377+
alt_database: Optional[str],
377378
name: Optional[str],
378379
use_threads: bool,
379380
s3_additional_kwargs: Optional[Dict[str, Any]],
380381
boto3_session: boto3.Session,
381382
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
382383
path: str = f"{s3_output}/{name}"
383384
ext_location: str = "\n" if wg_config.enforced is True else f",\n external_location = '{path}'\n"
385+
fully_qualified_name: str = f'"{alt_database}"."{name}"' if alt_database else f'"{database}"."{name}"'
384386
sql = (
385-
f'CREATE TABLE "{name}"\n'
387+
f"CREATE TABLE {fully_qualified_name}\n"
386388
f"WITH(\n"
387389
f" format = 'Parquet',\n"
388390
f" parquet_compression = 'SNAPPY'"
@@ -507,6 +509,7 @@ def _resolve_query_without_cache(
507509
encryption: Optional[str],
508510
kms_key: Optional[str],
509511
keep_files: bool,
512+
ctas_database_name: Optional[str],
510513
ctas_temp_table_name: Optional[str],
511514
use_threads: bool,
512515
s3_additional_kwargs: Optional[Dict[str, Any]],
@@ -538,6 +541,7 @@ def _resolve_query_without_cache(
538541
workgroup=workgroup,
539542
kms_key=kms_key,
540543
wg_config=wg_config,
544+
alt_database=ctas_database_name,
541545
name=name,
542546
use_threads=use_threads,
543547
s3_additional_kwargs=s3_additional_kwargs,
@@ -575,6 +579,7 @@ def read_sql_query(
575579
encryption: Optional[str] = None,
576580
kms_key: Optional[str] = None,
577581
keep_files: bool = True,
582+
ctas_database_name: Optional[str] = None,
578583
ctas_temp_table_name: Optional[str] = None,
579584
use_threads: bool = True,
580585
boto3_session: Optional[boto3.Session] = None,
@@ -709,6 +714,9 @@ def read_sql_query(
709714
For SSE-KMS, this is the KMS key ARN or ID.
710715
keep_files : bool
711716
Should Wrangler delete or keep the staging files produced by Athena?
717+
ctas_database_name : str, optional
718+
The name of the alternative database where the CTAS temporary table is stored.
719+
If None, the default `database` is used.
712720
ctas_temp_table_name : str, optional
713721
The name of the temporary table and also the directory name on S3 where the CTAS result is stored.
714722
If None, it will use the following random pattern: `f"temp_table_{uuid.uuid4().hex}"`.
@@ -820,6 +828,7 @@ def read_sql_query(
820828
encryption=encryption,
821829
kms_key=kms_key,
822830
keep_files=keep_files,
831+
ctas_database_name=ctas_database_name,
823832
ctas_temp_table_name=ctas_temp_table_name,
824833
use_threads=use_threads,
825834
s3_additional_kwargs=s3_additional_kwargs,
@@ -839,6 +848,7 @@ def read_sql_table(
839848
encryption: Optional[str] = None,
840849
kms_key: Optional[str] = None,
841850
keep_files: bool = True,
851+
ctas_database_name: Optional[str] = None,
842852
ctas_temp_table_name: Optional[str] = None,
843853
use_threads: bool = True,
844854
boto3_session: Optional[boto3.Session] = None,
@@ -967,6 +977,9 @@ def read_sql_table(
967977
For SSE-KMS, this is the KMS key ARN or ID.
968978
keep_files : bool
969979
Should Wrangler delete or keep the staging files produced by Athena?
980+
ctas_database_name : str, optional
981+
The name of the alternative database where the CTAS temporary table is stored.
982+
If None, the default `database` is used.
970983
ctas_temp_table_name : str, optional
971984
The name of the temporary table and also the directory name on S3 where the CTAS result is stored.
972985
If None, it will use the following random pattern: `f"temp_table_{uuid.uuid4().hex}"`.
@@ -1027,6 +1040,7 @@ def read_sql_table(
10271040
encryption=encryption,
10281041
kms_key=kms_key,
10291042
keep_files=keep_files,
1043+
ctas_database_name=ctas_database_name,
10301044
ctas_temp_table_name=ctas_temp_table_name,
10311045
use_threads=use_threads,
10321046
boto3_session=boto3_session,

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,16 @@ def redshift_external_schema(cloudformation_outputs, databases_parameters, glue_
170170
return "aws_data_wrangler_external"
171171

172172

173+
@pytest.fixture(scope="function")
174+
def glue_ctas_database():
175+
name = f"db_{get_time_str_with_random_suffix()}"
176+
print(f"Database name: {name}")
177+
wr.catalog.create_database(name=name)
178+
yield name
179+
wr.catalog.delete_database(name=name)
180+
print(f"Database {name} deleted.")
181+
182+
173183
@pytest.fixture(scope="function")
174184
def glue_table(glue_database: str) -> None:
175185
name = f"tbl_{get_time_str_with_random_suffix()}"

tests/test_athena.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
logging.getLogger("awswrangler").setLevel(logging.DEBUG)
2727

2828

29-
def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database, kms_key):
29+
def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database, glue_ctas_database, kms_key):
3030
df = get_df_list()
3131
columns_types, partitions_types = wr.catalog.extract_athena_types(df=df, partition_cols=["par0", "par1"])
3232
assert len(columns_types) == 17
@@ -102,6 +102,26 @@ def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database,
102102
ensure_athena_query_metadata(df=df, ctas_approach=True, encrypted=False)
103103
assert len(wr.s3.list_objects(path=path3)) > 2
104104

105+
# ctas_database_name
106+
wr.s3.delete_objects(path=path3)
107+
dfs = wr.athena.read_sql_query(
108+
sql=f"SELECT * FROM {glue_table}",
109+
database=glue_database,
110+
ctas_approach=True,
111+
chunksize=1,
112+
keep_files=False,
113+
ctas_database_name=glue_ctas_database,
114+
ctas_temp_table_name=glue_table2,
115+
s3_output=path3,
116+
)
117+
assert wr.catalog.does_table_exist(database=glue_ctas_database, table=glue_table2) is True
118+
assert len(wr.s3.list_objects(path=path3)) > 2
119+
assert len(wr.s3.list_objects(path=final_destination)) > 0
120+
for df in dfs:
121+
ensure_data_types(df=df, has_list=True)
122+
ensure_athena_query_metadata(df=df, ctas_approach=True, encrypted=False)
123+
assert len(wr.s3.list_objects(path=path3)) == 0
124+
105125

106126
def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1):
107127
table = f"__{glue_table}"

0 commit comments

Comments
 (0)