from math import floor
import copy
import csv
-from datetime import datetime
+from datetime import datetime, date
from decimal import Decimal
from ast import literal_eval


from awswrangler import data_types
from awswrangler.exceptions import (UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object,
-                                    LineTerminatorNotFound, EmptyDataframe, InvalidSerDe, InvalidCompression)
+                                    LineTerminatorNotFound, EmptyDataframe, InvalidSerDe, InvalidCompression,
+                                    InvalidParameters)
from awswrangler.utils import calculate_bounders
from awswrangler import s3
from awswrangler.athena import Athena
@@ -495,29 +496,100 @@ def read_sql_athena(self,
                        sql: str,
                        database: Optional[str] = None,
                        s3_output: Optional[str] = None,
-                        max_result_size: Optional[int] = None,
                        workgroup: Optional[str] = None,
                        encryption: Optional[str] = None,
-                        kms_key: Optional[str] = None):
+                        kms_key: Optional[str] = None,
+                        ctas_approach: bool = True,
+                        procs_cpu_bound: Optional[int] = None,
+                        max_result_size: Optional[int] = None):
        """
        Executes any SQL query on AWS Athena and returns a DataFrame of the result.
-        P.S. If max_result_size is passed, then a iterator of Dataframes is returned.
+        There are two approaches, selected through the ctas_approach parameter:
+        1 - ctas_approach True (default):
+            Wraps the query in a CTAS and then reads the resulting table data as Parquet directly from S3.
+            PROS: Faster and handles nested types better.
+            CONS: Cannot be combined with max_result_size.
+        2 - ctas_approach False:
+            Runs a regular query on Athena and parses the CSV result from S3.
+            PROS: Accepts max_result_size.
+            CONS: Slower (but still faster than other libraries that use the Athena API) and does not handle nested types as well.
+
+        P.S. If ctas_approach is False and max_result_size is passed, then an iterator of DataFrames is returned.
        P.S.S. All default values will be inherited from the Session()

        :param sql: SQL Query
        :param database: Glue/Athena Database
        :param s3_output: AWS S3 path
-        :param max_result_size: Max number of bytes on each request to S3
        :param workgroup: The name of the workgroup in which the query is being started. (By default uses the Session() workgroup)
        :param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
        :param kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
-        :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
+        :param ctas_approach: Wraps the query in a CTAS (see above)
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
+        :param max_result_size: Max number of bytes on each request to S3 (VALID ONLY FOR ctas_approach=False)
+        :return: Pandas DataFrame, or an iterator of Pandas DataFrames if max_result_size was passed
        """
+        if ctas_approach is True and max_result_size is not None:
+            raise InvalidParameters("ctas_approach can't use max_result_size!")
        if s3_output is None:
            if self._session.athena_s3_output is not None:
                s3_output = self._session.athena_s3_output
            else:
                s3_output = self._session.athena.create_athena_bucket()
+        if ctas_approach is False:
+            return self._read_sql_athena_regular(sql=sql,
+                                                 database=database,
+                                                 s3_output=s3_output,
+                                                 workgroup=workgroup,
+                                                 encryption=encryption,
+                                                 kms_key=kms_key,
+                                                 max_result_size=max_result_size)
+        else:
+            return self._read_sql_athena_ctas(sql=sql,
+                                              database=database,
+                                              s3_output=s3_output,
+                                              workgroup=workgroup,
+                                              encryption=encryption,
+                                              kms_key=kms_key,
+                                              procs_cpu_bound=procs_cpu_bound)
+
+    def _read_sql_athena_ctas(self,
+                              sql: str,
+                              s3_output: str,
+                              database: Optional[str] = None,
+                              workgroup: Optional[str] = None,
+                              encryption: Optional[str] = None,
+                              kms_key: Optional[str] = None,
+                              procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
+        guid: str = pa.compat.guid()
+        name: str = f"temp_table_{guid}"
+        s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
+        path: str = f"{s3_output}/{name}"
+        query: str = f"CREATE TABLE {name}\n" \
+                     f"WITH(\n" \
+                     f"    format = 'Parquet',\n" \
+                     f"    parquet_compression = 'SNAPPY',\n" \
+                     f"    external_location = '{path}'\n" \
+                     f") AS\n" \
+                     f"{sql}"
+        logger.debug(f"query: {query}")
+        query_id: str = self._session.athena.run_query(query=query,
+                                                       database=database,
+                                                       s3_output=s3_output,
+                                                       workgroup=workgroup,
+                                                       encryption=encryption,
+                                                       kms_key=kms_key)
+        self._session.athena.wait_query(query_execution_id=query_id)
+        self._session.glue.delete_table_if_exists(database=database, table=name)
+        return self.read_parquet(path=path, procs_cpu_bound=procs_cpu_bound)
+
+    def _read_sql_athena_regular(self,
+                                 sql: str,
+                                 s3_output: str,
+                                 database: Optional[str] = None,
+                                 workgroup: Optional[str] = None,
+                                 encryption: Optional[str] = None,
+                                 kms_key: Optional[str] = None,
+                                 max_result_size: Optional[int] = None):
        query_execution_id: str = self._session.athena.run_query(query=sql,
                                                                  database=database,
                                                                  s3_output=s3_output,
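
For illustration, here is a rough sketch of the CTAS statement that _read_sql_athena_ctas assembles above; the query, temp table name, and S3 path below are made-up examples, not values produced by this change:

sql = "SELECT col_a, col_b FROM my_table WHERE col_a > 0"    # hypothetical user query
name = "temp_table_8f14e45fceea167a"                         # in practice comes from pa.compat.guid()
path = "s3://my-athena-results/temp_table_8f14e45fceea167a"  # f"{s3_output}/{name}"

query = (f"CREATE TABLE {name}\n"
         f"WITH(\n"
         f"    format = 'Parquet',\n"
         f"    parquet_compression = 'SNAPPY',\n"
         f"    external_location = '{path}'\n"
         f") AS\n"
         f"{sql}")
# Athena materializes the result as Snappy-compressed Parquet under `path`; the temporary
# Glue table is then dropped and the Parquet files are read back via read_parquet().
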
@@ -542,7 +614,10 @@ def read_sql_athena(self,
        if max_result_size is None:
            if len(ret.index) > 0:
                for col in parse_dates:
-                    ret[col] = ret[col].dt.date.replace(to_replace={pd.NaT: None})
+                    if str(ret[col].dtype) == "object":
+                        ret[col] = ret[col].apply(lambda x: date(*[int(y) for y in x.split("-")]))
+                    else:
+                        ret[col] = ret[col].dt.date.replace(to_replace={pd.NaT: None})
            return ret
        else:
            return Pandas._apply_dates_to_generator(generator=ret, parse_dates=parse_dates)
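
A minimal usage sketch of the two approaches above, assuming the pre-1.0 awswrangler Session entry point; database and table names are hypothetical:

import awswrangler

session = awswrangler.Session()

# Default: CTAS approach (faster, better nested-type handling), returns a single DataFrame
df = session.pandas.read_sql_athena(sql="SELECT * FROM my_table", database="my_db")

# Regular approach: CSV parsing, optionally chunked through max_result_size
df_iter = session.pandas.read_sql_athena(sql="SELECT * FROM my_table",
                                         database="my_db",
                                         ctas_approach=False,
                                         max_result_size=128_000_000)  # ~128 MB per S3 request
for chunk in df_iter:
    print(len(chunk.index))
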
@@ -1151,5 +1226,29 @@ def read_parquet(self,
        use_threads: bool = True if procs_cpu_bound > 1 else False
        fs: S3FileSystem = s3.get_fs(session_primitives=self._session.primitives)
        fs = pa.filesystem._ensure_filesystem(fs)
-        return pq.read_table(source=path, columns=columns, filters=filters,
-                             filesystem=fs).to_pandas(use_threads=use_threads)
+        table = pq.read_table(source=path, columns=columns, filters=filters, filesystem=fs, use_threads=use_threads)
+        # Check whether any integer column lost its type in the conversion (happens when the column has null values)
+        integers = [field.name for field in table.schema if str(field.type).startswith("int")]
+        df = table.to_pandas(use_threads=use_threads, integer_object_nulls=True)
+        for c in integers:
+            if not str(df[c].dtype).startswith("int"):
+                df[c] = df[c].astype("Int64")
+        return df
+
+    def read_table(self,
+                   database: str,
+                   table: str,
+                   columns: Optional[List[str]] = None,
+                   filters: Optional[Union[List[Tuple[Any]], List[List[Tuple[Any]]]]] = None,
+                   procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
+        """
+        Read a Parquet table from S3 using its Glue Catalog location, skipping the need to go through Athena.
+
+        :param database: Database name
+        :param table: Table name
+        :param columns: Names of the columns to read from the files
+        :param filters: List of filters to apply, like ``[[('x', '=', 0), ...], ...]``.
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
+        """
+        path: str = self._session.glue.get_table_location(database=database, table=table)
+        return self.read_parquet(path=path, columns=columns, filters=filters, procs_cpu_bound=procs_cpu_bound)
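
And a short sketch of the new read_table shortcut together with the nullable-integer behaviour introduced in read_parquet (database and table names are hypothetical):

import awswrangler

session = awswrangler.Session()

# Reads the Parquet files straight from the table's Glue Catalog location; no Athena query is run
df = session.pandas.read_table(database="my_db", table="my_parquet_table")

# Integer columns containing nulls are no longer silently cast to float64:
# read_parquet() now converts them to pandas' nullable Int64 extension dtype.
print(df.dtypes)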