Commit 10ea9e8

Merge pull request #221 from awslabs/athena-args
Add keep_files and ctas_temp_table_name to wr.athena.read_*().
2 parents 97f8763 + 458bf26

File tree: 3 files changed (+130, -28 lines)
awswrangler/athena.py (61 additions, 15 deletions)
```diff
@@ -370,7 +370,7 @@ def _fix_csv_types(df: pd.DataFrame, parse_dates: List[str], binaries: List[str]
     return df
 
 
-def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
+def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals,too-many-return-statements,too-many-statements
     sql: str,
     database: str,
     ctas_approach: bool = True,
@@ -380,6 +380,8 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
     workgroup: Optional[str] = None,
     encryption: Optional[str] = None,
     kms_key: Optional[str] = None,
+    keep_files: bool = True,
+    ctas_temp_table_name: Optional[str] = None,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
@@ -454,6 +456,12 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
         Valid values: [None, 'SSE_S3', 'SSE_KMS']. Notice: 'CSE_KMS' is not supported.
     kms_key : str, optional
         For SSE-KMS, this is the KMS key ARN or ID.
+    keep_files : bool
+        Should Wrangler delete or keep the staging files produced by Athena?
+    ctas_temp_table_name : str, optional
+        The name of the temporary table and also the directory name on S3 where the CTAS result is stored.
+        If None, it will use the following random pattern: `f"temp_table_{pyarrow.compat.guid()}"`.
+        On S3 this directory will be under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`.
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
```
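The docstring's path pattern is worth spelling out. A small sketch of the layout it describes, with a hypothetical bucket, assuming the pyarrow version pinned at the time still ships `pyarrow.compat.guid`:

```python
import pyarrow as pa

s3_output = "s3://my-bucket/athena-results"  # hypothetical staging prefix

# Explicit name: the CTAS result lands under {s3_output}/{ctas_temp_table_name}/
explicit = f"{s3_output}/my_temp_table/"

# Default name: the same fallback the code uses, a random temp_table_<guid>
default = f"{s3_output}/temp_table_{pa.compat.guid()}/"

print(explicit)  # s3://my-bucket/athena-results/my_temp_table/
print(default)   # s3://my-bucket/athena-results/temp_table_<guid>/
```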
```diff
@@ -477,7 +485,10 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
     _s3_output = _s3_output[:-1] if _s3_output[-1] == "/" else _s3_output
     name: str = ""
     if ctas_approach is True:
-        name = f"temp_table_{pa.compat.guid()}"
+        if ctas_temp_table_name is not None:
+            name = catalog.sanitize_table_name(ctas_temp_table_name)
+        else:
+            name = f"temp_table_{pa.compat.guid()}"
         path: str = f"{_s3_output}/{name}"
         ext_location: str = "\n" if wg_config["enforced"] is True else f",\n external_location = '{path}'\n"
         sql = (
@@ -506,25 +517,34 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
         reason: str = query_response["QueryExecution"]["Status"]["StateChangeReason"]
         message_error: str = f"Query error: {reason}"
         raise exceptions.AthenaQueryError(message_error)
-    dfs: Union[pd.DataFrame, Iterator[pd.DataFrame]]
+    ret: Union[pd.DataFrame, Iterator[pd.DataFrame]]
     if ctas_approach is True:
         catalog.delete_table_if_exists(database=database, table=name, boto3_session=session)
         manifest_path: str = f"{_s3_output}/tables/{query_id}-manifest.csv"
+        metadata_path: str = f"{_s3_output}/tables/{query_id}.metadata"
         _logger.debug("manifest_path: %s", manifest_path)
+        _logger.debug("metadata_path: %s", metadata_path)
+        s3.wait_objects_exist(paths=[manifest_path, metadata_path], use_threads=False, boto3_session=session)
         paths: List[str] = _extract_ctas_manifest_paths(path=manifest_path, boto3_session=session)
         chunked: Union[bool, int] = False if chunksize is None else chunksize
         _logger.debug("chunked: %s", chunked)
         if not paths:
             if chunked is False:
-                dfs = pd.DataFrame()
-            else:
-                dfs = _utils.empty_generator()
-        else:
-            s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session)
-            dfs = s3.read_parquet(
-                path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked, categories=categories
-            )
-        return dfs
+                return pd.DataFrame()
+            return _utils.empty_generator()
+        s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session)
+        ret = s3.read_parquet(
+            path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked, categories=categories
+        )
+        paths_delete: List[str] = paths + [manifest_path, metadata_path]
+        _logger.debug(type(ret))
+        if chunked is False:
+            if keep_files is False:
+                s3.delete_objects(path=paths_delete, use_threads=use_threads, boto3_session=session)
+            return ret
+        if keep_files is False:
+            return _delete_after_iterate(dfs=ret, paths=paths_delete, use_threads=use_threads, boto3_session=session)
+        return ret
     dtype, parse_timestamps, parse_dates, converters, binaries = _get_query_metadata(
         query_execution_id=query_id, categories=categories, boto3_session=session
     )
@@ -547,10 +567,26 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
         boto3_session=session,
     )
     _logger.debug("Start type casting...")
-    if chunksize is None:
-        return _fix_csv_types(df=ret, parse_dates=parse_dates, binaries=binaries)
     _logger.debug(type(ret))
-    return _fix_csv_types_generator(dfs=ret, parse_dates=parse_dates, binaries=binaries)
+    if chunksize is None:
+        df = _fix_csv_types(df=ret, parse_dates=parse_dates, binaries=binaries)
+        if keep_files is False:
+            s3.delete_objects(path=[path, f"{path}.metadata"], use_threads=use_threads, boto3_session=session)
+        return df
+    dfs = _fix_csv_types_generator(dfs=ret, parse_dates=parse_dates, binaries=binaries)
+    if keep_files is False:
+        return _delete_after_iterate(
+            dfs=dfs, paths=[path, f"{path}.metadata"], use_threads=use_threads, boto3_session=session
+        )
+    return dfs
+
+
+def _delete_after_iterate(
+    dfs: Iterator[pd.DataFrame], paths: List[str], use_threads: bool, boto3_session: boto3.Session
+) -> Iterator[pd.DataFrame]:
+    for df in dfs:
+        yield df
+    s3.delete_objects(path=paths, use_threads=use_threads, boto3_session=boto3_session)
 
 
 def stop_query_execution(query_execution_id: str, boto3_session: Optional[boto3.Session] = None) -> None:
```
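The `_delete_after_iterate` helper is what makes `keep_files=False` safe together with `chunksize`: deletion must wait until the caller has consumed every chunk. A self-contained sketch of the same pattern, with `print` standing in for `s3.delete_objects`:

```python
from typing import Iterator, List

def delete_after_iterate(chunks: Iterator[str], paths: List[str]) -> Iterator[str]:
    """Yield every chunk first, then clean up; mirrors the helper above."""
    for chunk in chunks:
        yield chunk  # the consumer may still be reading staged data here
    print(f"deleting {paths}")  # stand-in for s3.delete_objects(...)

# The cleanup line only runs once this loop exhausts the generator.
for c in delete_after_iterate(iter(["chunk-0", "chunk-1"]), ["s3://bucket/tmp/x"]):
    print(c)
```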
```diff
@@ -638,6 +674,8 @@ def read_sql_table(
     workgroup: Optional[str] = None,
     encryption: Optional[str] = None,
     kms_key: Optional[str] = None,
+    keep_files: bool = True,
+    ctas_temp_table_name: Optional[str] = None,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
@@ -712,6 +750,12 @@ def read_sql_table(
         None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
     kms_key : str, optional
         For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
+    keep_files : bool
+        Should Wrangler delete or keep the staging files produced by Athena?
+    ctas_temp_table_name : str, optional
+        The name of the temporary table and also the directory name on S3 where the CTAS result is stored.
+        If None, it will use the following random pattern: `f"temp_table_{pyarrow.compat.guid()}"`.
+        On S3 this directory will be under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`.
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -740,6 +784,8 @@ def read_sql_table(
         workgroup=workgroup,
         encryption=encryption,
         kms_key=kms_key,
+        keep_files=keep_files,
+        ctas_temp_table_name=ctas_temp_table_name,
         use_threads=use_threads,
         boto3_session=boto3_session,
     )
```
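Since the last hunk shows `read_sql_table` forwarding both new arguments to `read_sql_query` unchanged, the behavior is identical for table reads. A hedged example with hypothetical names:

```python
import awswrangler as wr

# read_sql_table builds a SELECT over the table and delegates to
# read_sql_query, passing keep_files / ctas_temp_table_name through as-is.
df = wr.athena.read_sql_table(
    table="my_table",                      # hypothetical table
    database="my_database",                # hypothetical database
    keep_files=False,
    ctas_temp_table_name="my_temp_table",
)
```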

awswrangler/torch.py (10 additions, 10 deletions)
```diff
@@ -28,14 +28,14 @@ class _BaseS3Dataset:
     def __init__(
         self, path: Union[str, List[str]], suffix: Optional[str] = None, boto3_session: Optional[boto3.Session] = None
     ):
-        """PyTorch Map-Style S3 Dataset.
+        r"""PyTorch Map-Style S3 Dataset.
 
         Parameters
         ----------
         path : Union[str, List[str]]
             S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]).
         suffix: str, optional
-            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png).
+            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://\*.png).
         boto3_session : boto3.Session(), optional
             Boto3 Session. The default boto3 session will be used if boto3_session receives None.
 
@@ -160,7 +160,7 @@ def __init__(
         suffix: Optional[str] = None,
         boto3_session: Optional[boto3.Session] = None,
     ):
-        """PyTorch Amazon S3 Lambda Dataset.
+        r"""PyTorch Amazon S3 Lambda Dataset.
 
         Parameters
         ----------
@@ -171,7 +171,7 @@ def __init__(
         label_fn: Callable
             Function that receives an object path (str) and returns a torch.Tensor
         suffix: str, optional
-            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png).
+            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://\*.png).
         boto3_session : boto3.Session(), optional
             Boto3 Session. The default boto3 session will be used if boto3_session receives None.
 
@@ -212,7 +212,7 @@ def __init__(
         suffix: Optional[str] = None,
         boto3_session: Optional[boto3.Session] = None,
     ):
-        """PyTorch Amazon S3 Audio Dataset.
+        r"""PyTorch Amazon S3 Audio Dataset.
 
         Read individual WAV audio files stored in Amazon S3 and return
         them as torch tensors.
@@ -237,7 +237,7 @@ def __init__(
         path : Union[str, List[str]]
             S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]).
         suffix: str, optional
-            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png).
+            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://\*.png).
         boto3_session : boto3.Session(), optional
             Boto3 Session. The default boto3 session will be used if boto3_session receives None.
 
@@ -302,7 +302,7 @@ class ImageS3Dataset(_S3PartitionedDataset):
     """PyTorch Amazon S3 Image Dataset."""
 
     def __init__(self, path: Union[str, List[str]], suffix: str, boto3_session: boto3.Session):
-        """PyTorch Amazon S3 Image Dataset.
+        r"""PyTorch Amazon S3 Image Dataset.
 
         ImageS3Dataset assumes images are partitioned (within class=<value> folders) in Amazon S3.
         Each listed object will be loaded by the default Pillow library.
@@ -327,7 +327,7 @@ def __init__(self, path: Union[str, List[str]], suffix: str, boto3_session: boto
         path : Union[str, List[str]]
             S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]).
         suffix: str, optional
-            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png).
+            S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://\*.png).
         boto3_session : boto3.Session(), optional
             Boto3 Session. The default boto3 session will be used if boto3_session receives None.
 
@@ -350,14 +350,14 @@ def _data_fn(self, data: io.BytesIO) -> Any:
 
 
 class S3IterableDataset(IterableDataset, _BaseS3Dataset):  # pylint: disable=abstract-method
-    """PyTorch Amazon S3 Iterable Dataset.
+    r"""PyTorch Amazon S3 Iterable Dataset.
 
     Parameters
     ----------
     path : Union[str, List[str]]
         S3 prefix (e.g. s3://bucket/prefix) or list of S3 objects paths (e.g. [s3://bucket/key0, s3://bucket/key1]).
     suffix: str, optional
-        S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://*.png).
+        S3 suffix filtering of object keys (i.e. suffix=".png" -> s3://\*.png).
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receives None.
```

testing/test_awswrangler/test_data_lake.py (59 additions, 3 deletions)
```diff
@@ -191,14 +191,51 @@ def test_athena_ctas(bucket, database, kms_key):
         encryption="SSE_KMS",
         kms_key=kms_key,
         s3_output=f"s3://{bucket}/test_athena_ctas_result",
+        keep_files=False,
     )
     assert len(df.index) == 3
     ensure_data_types(df=df, has_list=True)
+    temp_table = "test_athena_ctas2"
+    s3_output = f"s3://{bucket}/s3_output/"
+    final_destination = f"{s3_output}{temp_table}/"
+
+    # keep_files=False
+    wr.s3.delete_objects(path=s3_output)
     dfs = wr.athena.read_sql_query(
-        sql=f"SELECT * FROM test_athena_ctas", database=database, ctas_approach=True, chunksize=1
+        sql=f"SELECT * FROM test_athena_ctas",
+        database=database,
+        ctas_approach=True,
+        chunksize=1,
+        keep_files=False,
+        ctas_temp_table_name=temp_table,
+        s3_output=s3_output,
     )
+    assert wr.catalog.does_table_exist(database=database, table=temp_table) is False
+    assert len(wr.s3.list_objects(path=s3_output)) > 2
+    assert len(wr.s3.list_objects(path=final_destination)) > 0
     for df in dfs:
         ensure_data_types(df=df, has_list=True)
+    assert len(wr.s3.list_objects(path=s3_output)) == 0
+
+    # keep_files=True
+    wr.s3.delete_objects(path=s3_output)
+    dfs = wr.athena.read_sql_query(
+        sql=f"SELECT * FROM test_athena_ctas",
+        database=database,
+        ctas_approach=True,
+        chunksize=2,
+        keep_files=True,
+        ctas_temp_table_name=temp_table,
+        s3_output=s3_output,
+    )
+    assert wr.catalog.does_table_exist(database=database, table=temp_table) is False
+    assert len(wr.s3.list_objects(path=s3_output)) > 2
+    assert len(wr.s3.list_objects(path=final_destination)) > 0
+    for df in dfs:
+        ensure_data_types(df=df, has_list=True)
+    assert len(wr.s3.list_objects(path=s3_output)) > 2
+
+    # Cleaning Up
     wr.catalog.delete_table_if_exists(database=database, table="test_athena_ctas")
     wr.s3.delete_objects(path=paths)
     wr.s3.wait_objects_not_exist(paths=paths)
@@ -227,12 +264,17 @@ def test_athena(bucket, database, kms_key, workgroup0, workgroup1):
         encryption="SSE_KMS",
         kms_key=kms_key,
         workgroup=workgroup0,
+        keep_files=False,
     )
     for df2 in dfs:
         print(df2)
         ensure_data_types(df=df2)
     df = wr.athena.read_sql_query(
-        sql="SELECT * FROM __test_athena", database=database, ctas_approach=False, workgroup=workgroup1
+        sql="SELECT * FROM __test_athena",
+        database=database,
+        ctas_approach=False,
+        workgroup=workgroup1,
+        keep_files=False,
     )
     assert len(df.index) == 3
     ensure_data_types(df=df)
@@ -1195,9 +1237,23 @@ def test_athena_encryption(
         df=df, path=path, dataset=True, mode="overwrite", database=database, table=table, s3_additional_kwargs=None
     )["paths"]
     wr.s3.wait_objects_exist(paths=paths, use_threads=False)
+    temp_table = table + "2"
+    s3_output = f"s3://{bucket}/encryptio_s3_output/"
+    final_destination = f"{s3_output}{temp_table}/"
+    wr.s3.delete_objects(path=final_destination)
     df2 = wr.athena.read_sql_table(
-        table=table, ctas_approach=True, database=database, encryption=encryption, workgroup=workgroup, kms_key=kms_key
+        table=table,
+        ctas_approach=True,
+        database=database,
+        encryption=encryption,
+        workgroup=workgroup,
+        kms_key=kms_key,
+        keep_files=True,
+        ctas_temp_table_name=temp_table,
+        s3_output=s3_output,
     )
+    assert wr.catalog.does_table_exist(database=database, table=temp_table) is False
+    assert len(wr.s3.list_objects(path=s3_output)) > 2
     print(df2)
     assert len(df2.index) == 2
     assert len(df2.columns) == 2
```
