@@ -370,7 +370,7 @@ def _fix_csv_types(df: pd.DataFrame, parse_dates: List[str], binaries: List[str]
370370 return df
371371
372372
373- def read_sql_query ( # pylint: disable=too-many-branches,too-many-locals
373+ def read_sql_query ( # pylint: disable=too-many-branches,too-many-locals,too-many-return-statements,too-many-statements
374374 sql : str ,
375375 database : str ,
376376 ctas_approach : bool = True ,
@@ -380,6 +380,8 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
380380 workgroup : Optional [str ] = None ,
381381 encryption : Optional [str ] = None ,
382382 kms_key : Optional [str ] = None ,
383+ keep_files : bool = True ,
384+ ctas_temp_table_name : Optional [str ] = None ,
383385 use_threads : bool = True ,
384386 boto3_session : Optional [boto3 .Session ] = None ,
385387) -> Union [pd .DataFrame , Iterator [pd .DataFrame ]]:
@@ -454,6 +456,12 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
454456 Valid values: [None, 'SSE_S3', 'SSE_KMS']. Notice: 'CSE_KMS' is not supported.
455457 kms_key : str, optional
456458 For SSE-KMS, this is the KMS key ARN or ID.
459+ keep_files : bool
460+ Should Wrangler delete or keep the staging files produced by Athena?
461+ ctas_temp_table_name : str, optional
462+ The name of the temporary table and also the directory name on S3 where the CTAS result is stored.
463+ If None, it will use the following random pattern: `f"temp_table_{pyarrow.compat.guid()}"`.
464+ On S3 this directory will be under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`.
457465 use_threads : bool
458466 True to enable concurrent requests, False to disable multiple threads.
459467 If enabled os.cpu_count() will be used as the max number of threads.
@@ -477,7 +485,10 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
477485 _s3_output = _s3_output [:- 1 ] if _s3_output [- 1 ] == "/" else _s3_output
478486 name : str = ""
479487 if ctas_approach is True :
480- name = f"temp_table_{ pa .compat .guid ()} "
488+ if ctas_temp_table_name is not None :
489+ name = catalog .sanitize_table_name (ctas_temp_table_name )
490+ else :
491+ name = f"temp_table_{ pa .compat .guid ()} "
481492 path : str = f"{ _s3_output } /{ name } "
482493 ext_location : str = "\n " if wg_config ["enforced" ] is True else f",\n external_location = '{ path } '\n "
483494 sql = (
@@ -506,25 +517,34 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
506517 reason : str = query_response ["QueryExecution" ]["Status" ]["StateChangeReason" ]
507518 message_error : str = f"Query error: { reason } "
508519 raise exceptions .AthenaQueryError (message_error )
509- dfs : Union [pd .DataFrame , Iterator [pd .DataFrame ]]
520+ ret : Union [pd .DataFrame , Iterator [pd .DataFrame ]]
510521 if ctas_approach is True :
511522 catalog .delete_table_if_exists (database = database , table = name , boto3_session = session )
512523 manifest_path : str = f"{ _s3_output } /tables/{ query_id } -manifest.csv"
524+ metadata_path : str = f"{ _s3_output } /tables/{ query_id } .metadata"
513525 _logger .debug ("manifest_path: %s" , manifest_path )
526+ _logger .debug ("metadata_path: %s" , metadata_path )
527+ s3 .wait_objects_exist (paths = [manifest_path , metadata_path ], use_threads = False , boto3_session = session )
514528 paths : List [str ] = _extract_ctas_manifest_paths (path = manifest_path , boto3_session = session )
515529 chunked : Union [bool , int ] = False if chunksize is None else chunksize
516530 _logger .debug ("chunked: %s" , chunked )
517531 if not paths :
518532 if chunked is False :
519- dfs = pd .DataFrame ()
520- else :
521- dfs = _utils .empty_generator ()
522- else :
523- s3 .wait_objects_exist (paths = paths , use_threads = False , boto3_session = session )
524- dfs = s3 .read_parquet (
525- path = paths , use_threads = use_threads , boto3_session = session , chunked = chunked , categories = categories
526- )
527- return dfs
533+ return pd .DataFrame ()
534+ return _utils .empty_generator ()
535+ s3 .wait_objects_exist (paths = paths , use_threads = False , boto3_session = session )
536+ ret = s3 .read_parquet (
537+ path = paths , use_threads = use_threads , boto3_session = session , chunked = chunked , categories = categories
538+ )
539+ paths_delete : List [str ] = paths + [manifest_path , metadata_path ]
540+ _logger .debug (type (ret ))
541+ if chunked is False :
542+ if keep_files is False :
543+ s3 .delete_objects (path = paths_delete , use_threads = use_threads , boto3_session = session )
544+ return ret
545+ if keep_files is False :
546+ return _delete_after_iterate (dfs = ret , paths = paths_delete , use_threads = use_threads , boto3_session = session )
547+ return ret
528548 dtype , parse_timestamps , parse_dates , converters , binaries = _get_query_metadata (
529549 query_execution_id = query_id , categories = categories , boto3_session = session
530550 )
@@ -547,10 +567,26 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
547567 boto3_session = session ,
548568 )
549569 _logger .debug ("Start type casting..." )
550- if chunksize is None :
551- return _fix_csv_types (df = ret , parse_dates = parse_dates , binaries = binaries )
552570 _logger .debug (type (ret ))
553- return _fix_csv_types_generator (dfs = ret , parse_dates = parse_dates , binaries = binaries )
571+ if chunksize is None :
572+ df = _fix_csv_types (df = ret , parse_dates = parse_dates , binaries = binaries )
573+ if keep_files is False :
574+ s3 .delete_objects (path = [path , f"{ path } .metadata" ], use_threads = use_threads , boto3_session = session )
575+ return df
576+ dfs = _fix_csv_types_generator (dfs = ret , parse_dates = parse_dates , binaries = binaries )
577+ if keep_files is False :
578+ return _delete_after_iterate (
579+ dfs = dfs , paths = [path , f"{ path } .metadata" ], use_threads = use_threads , boto3_session = session
580+ )
581+ return dfs
582+
583+
def _delete_after_iterate(
    dfs: Iterator[pd.DataFrame], paths: List[str], use_threads: bool, boto3_session: boto3.Session
) -> Iterator[pd.DataFrame]:
    """Yield every DataFrame produced by ``dfs`` and, once the iterator is exhausted, delete ``paths`` from S3.

    NOTE(review): the S3 cleanup only happens if the caller fully consumes the
    generator — an abandoned iteration leaves the staging files behind. Confirm
    that is acceptable for all callers.
    """
    for frame in dfs:
        yield frame
    # All chunks delivered — now it is safe to remove the staging objects.
    s3.delete_objects(path=paths, use_threads=use_threads, boto3_session=boto3_session)
554590
555591
556592def stop_query_execution (query_execution_id : str , boto3_session : Optional [boto3 .Session ] = None ) -> None :
@@ -638,6 +674,8 @@ def read_sql_table(
638674 workgroup : Optional [str ] = None ,
639675 encryption : Optional [str ] = None ,
640676 kms_key : Optional [str ] = None ,
677+ keep_files : bool = True ,
678+ ctas_temp_table_name : Optional [str ] = None ,
641679 use_threads : bool = True ,
642680 boto3_session : Optional [boto3 .Session ] = None ,
643681) -> Union [pd .DataFrame , Iterator [pd .DataFrame ]]:
@@ -712,6 +750,12 @@ def read_sql_table(
712750 None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
713751 kms_key : str, optional
714752 For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
753+ keep_files : bool
754+ Should Wrangler delete or keep the staging files produced by Athena?
755+ ctas_temp_table_name : str, optional
756+ The name of the temporary table and also the directory name on S3 where the CTAS result is stored.
756+ If None, it will use the following random pattern: `f"temp_table_{pyarrow.compat.guid()}"`.
757+ On S3 this directory will be under the pattern: `f"{s3_output}/{ctas_temp_table_name}/"`.
715759 use_threads : bool
716760 True to enable concurrent requests, False to disable multiple threads.
717761 If enabled os.cpu_count() will be used as the max number of threads.
@@ -740,6 +784,8 @@ def read_sql_table(
740784 workgroup = workgroup ,
741785 encryption = encryption ,
742786 kms_key = kms_key ,
787+ keep_files = keep_files ,
788+ ctas_temp_table_name = ctas_temp_table_name ,
743789 use_threads = use_threads ,
744790 boto3_session = boto3_session ,
745791 )
0 commit comments